diff --git a/go/tools/asthelpergen/asthelpergen.go b/go/tools/asthelpergen/asthelpergen.go index 1811ff72511..3f59fdb3ece 100644 --- a/go/tools/asthelpergen/asthelpergen.go +++ b/go/tools/asthelpergen/asthelpergen.go @@ -29,7 +29,6 @@ import ( "golang.org/x/tools/go/packages" "vitess.io/vitess/go/textutil" - "vitess.io/vitess/go/tools/codegen" ) diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 1ff48b8be78..b510c81767c 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -365,8 +365,8 @@ type ( With *With Ignore Ignore Comments *ParsedComments - Targets TableNames TableExprs TableExprs + Targets TableNames Partitions Partitions Where *Where OrderBy OrderBy diff --git a/go/vt/sqlparser/ast_clone.go b/go/vt/sqlparser/ast_clone.go index b29b4c90047..912cba84e6c 100644 --- a/go/vt/sqlparser/ast_clone.go +++ b/go/vt/sqlparser/ast_clone.go @@ -1175,8 +1175,8 @@ func CloneRefOfDelete(n *Delete) *Delete { out := *n out.With = CloneRefOfWith(n.With) out.Comments = CloneRefOfParsedComments(n.Comments) - out.Targets = CloneTableNames(n.Targets) out.TableExprs = CloneTableExprs(n.TableExprs) + out.Targets = CloneTableNames(n.Targets) out.Partitions = ClonePartitions(n.Partitions) out.Where = CloneRefOfWhere(n.Where) out.OrderBy = CloneOrderBy(n.OrderBy) diff --git a/go/vt/sqlparser/ast_copy_on_rewrite.go b/go/vt/sqlparser/ast_copy_on_rewrite.go index 86dda29ebcf..65fab00c890 100644 --- a/go/vt/sqlparser/ast_copy_on_rewrite.go +++ b/go/vt/sqlparser/ast_copy_on_rewrite.go @@ -1850,18 +1850,18 @@ func (c *cow) copyOnRewriteRefOfDelete(n *Delete, parent SQLNode) (out SQLNode, if c.pre == nil || c.pre(n, parent) { _With, changedWith := c.copyOnRewriteRefOfWith(n.With, n) _Comments, changedComments := c.copyOnRewriteRefOfParsedComments(n.Comments, n) - _Targets, changedTargets := c.copyOnRewriteTableNames(n.Targets, n) _TableExprs, changedTableExprs := c.copyOnRewriteTableExprs(n.TableExprs, n) + _Targets, changedTargets := c.copyOnRewriteTableNames(n.Targets, n) _Partitions, changedPartitions := c.copyOnRewritePartitions(n.Partitions, n) _Where, changedWhere := c.copyOnRewriteRefOfWhere(n.Where, n) _OrderBy, changedOrderBy := c.copyOnRewriteOrderBy(n.OrderBy, n) _Limit, changedLimit := c.copyOnRewriteRefOfLimit(n.Limit, n) - if changedWith || changedComments || changedTargets || changedTableExprs || changedPartitions || changedWhere || changedOrderBy || changedLimit { + if changedWith || changedComments || changedTableExprs || changedTargets || changedPartitions || changedWhere || changedOrderBy || changedLimit { res := *n res.With, _ = _With.(*With) res.Comments, _ = _Comments.(*ParsedComments) - res.Targets, _ = _Targets.(TableNames) res.TableExprs, _ = _TableExprs.(TableExprs) + res.Targets, _ = _Targets.(TableNames) res.Partitions, _ = _Partitions.(Partitions) res.Where, _ = _Where.(*Where) res.OrderBy, _ = _OrderBy.(OrderBy) diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index 9beed3a8242..0ded1081fc3 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -2362,8 +2362,8 @@ func (cmp *Comparator) RefOfDelete(a, b *Delete) bool { return cmp.RefOfWith(a.With, b.With) && a.Ignore == b.Ignore && cmp.RefOfParsedComments(a.Comments, b.Comments) && - cmp.TableNames(a.Targets, b.Targets) && cmp.TableExprs(a.TableExprs, b.TableExprs) && + cmp.TableNames(a.Targets, b.Targets) && cmp.Partitions(a.Partitions, b.Partitions) && cmp.RefOfWhere(a.Where, b.Where) && cmp.OrderBy(a.OrderBy, b.OrderBy) && diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index 863de56bfba..a61399ae8ae 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -172,7 +172,7 @@ func (node *Delete) Format(buf *TrackedBuffer) { if node.Ignore { buf.literal("ignore ") } - if node.Targets != nil { + if node.Targets != nil && !node.isSingleAliasExpr() { buf.astPrintf(node, "%v ", node.Targets) } buf.astPrintf(node, "from %v%v%v%v%v", node.TableExprs, node.Partitions, node.Where, node.OrderBy, node.Limit) diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index 6f6f3594c18..37d3ddfa5b8 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -257,7 +257,7 @@ func (node *Delete) FormatFast(buf *TrackedBuffer) { if node.Ignore { buf.WriteString("ignore ") } - if node.Targets != nil { + if node.Targets != nil && !node.isSingleAliasExpr() { node.Targets.FormatFast(buf) buf.WriteByte(' ') } diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 3e8b54f7e08..1de529c973b 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -24,13 +24,11 @@ import ( "strconv" "strings" - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/vterrors" - - "vitess.io/vitess/go/vt/log" - "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" ) // Walk calls postVisit on every node. @@ -2156,25 +2154,31 @@ func (s SelectExprs) AllAggregation() bool { return true } -// RemoveKeyspaceFromColName removes the Qualifier.Qualifier on all ColNames in the expression tree -func RemoveKeyspaceFromColName(expr Expr) { - RemoveKeyspace(expr) -} - // RemoveKeyspace removes the Qualifier.Qualifier on all ColNames in the AST func RemoveKeyspace(in SQLNode) { // Walk will only return an error if we return an error from the inner func. safe to ignore here _ = Walk(func(node SQLNode) (kontinue bool, err error) { - switch col := node.(type) { - case *ColName: - if col.Qualifier.Qualifier.NotEmpty() { - col.Qualifier.Qualifier = NewIdentifierCS("") - } + if col, ok := node.(*ColName); ok && col.Qualifier.Qualifier.NotEmpty() { + col.Qualifier.Qualifier = NewIdentifierCS("") } + return true, nil }, in) } +// RemoveKeyspaceInTables removes the Qualifier on all TableNames in the AST +func RemoveKeyspaceInTables(in SQLNode) { + // Walk will only return an error if we return an error from the inner func. safe to ignore here + Rewrite(in, nil, func(cursor *Cursor) bool { + if tbl, ok := cursor.Node().(TableName); ok && tbl.Qualifier.NotEmpty() { + tbl.Qualifier = NewIdentifierCS("") + cursor.Replace(tbl) + } + + return true + }) +} + func convertStringToInt(integer string) int { val, _ := strconv.Atoi(integer) return val @@ -2536,3 +2540,14 @@ func IsLiteral(expr Expr) bool { func (ct *ColumnType) Invisible() bool { return ct.Options.Invisible != nil && *ct.Options.Invisible } + +func (node *Delete) isSingleAliasExpr() bool { + if len(node.Targets) > 1 { + return false + } + if len(node.TableExprs) != 1 { + return false + } + _, isAliasExpr := node.TableExprs[0].(*AliasedTableExpr) + return isAliasExpr +} diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index 0121695fe8c..6ec89e9a2ba 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -2455,13 +2455,13 @@ func (a *application) rewriteRefOfDelete(parent SQLNode, node *Delete, replacer }) { return false } - if !a.rewriteTableNames(node, node.Targets, func(newNode, parent SQLNode) { - parent.(*Delete).Targets = newNode.(TableNames) + if !a.rewriteTableExprs(node, node.TableExprs, func(newNode, parent SQLNode) { + parent.(*Delete).TableExprs = newNode.(TableExprs) }) { return false } - if !a.rewriteTableExprs(node, node.TableExprs, func(newNode, parent SQLNode) { - parent.(*Delete).TableExprs = newNode.(TableExprs) + if !a.rewriteTableNames(node, node.Targets, func(newNode, parent SQLNode) { + parent.(*Delete).Targets = newNode.(TableNames) }) { return false } diff --git a/go/vt/sqlparser/ast_visit.go b/go/vt/sqlparser/ast_visit.go index a88d689f102..bb2ec7c3500 100644 --- a/go/vt/sqlparser/ast_visit.go +++ b/go/vt/sqlparser/ast_visit.go @@ -1377,10 +1377,10 @@ func VisitRefOfDelete(in *Delete, f Visit) error { if err := VisitRefOfParsedComments(in.Comments, f); err != nil { return err } - if err := VisitTableNames(in.Targets, f); err != nil { + if err := VisitTableExprs(in.TableExprs, f); err != nil { return err } - if err := VisitTableExprs(in.TableExprs, f); err != nil { + if err := VisitTableNames(in.Targets, f); err != nil { return err } if err := VisitPartitions(in.Partitions, f); err != nil { diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index a31b5767baa..ebac6a68e23 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -1106,13 +1106,6 @@ func (cached *Delete) CachedSize(alloc bool) int64 { size += cached.With.CachedSize(true) // field Comments *vitess.io/vitess/go/vt/sqlparser.ParsedComments size += cached.Comments.CachedSize(true) - // field Targets vitess.io/vitess/go/vt/sqlparser.TableNames - { - size += hack.RuntimeAllocSize(int64(cap(cached.Targets)) * int64(32)) - for _, elem := range cached.Targets { - size += elem.CachedSize(false) - } - } // field TableExprs vitess.io/vitess/go/vt/sqlparser.TableExprs { size += hack.RuntimeAllocSize(int64(cap(cached.TableExprs)) * int64(16)) @@ -1122,6 +1115,13 @@ func (cached *Delete) CachedSize(alloc bool) int64 { } } } + // field Targets vitess.io/vitess/go/vt/sqlparser.TableNames + { + size += hack.RuntimeAllocSize(int64(cap(cached.Targets)) * int64(32)) + for _, elem := range cached.Targets { + size += elem.CachedSize(false) + } + } // field Partitions vitess.io/vitess/go/vt/sqlparser.Partitions { size += hack.RuntimeAllocSize(int64(cap(cached.Partitions)) * int64(32)) diff --git a/go/vt/sqlparser/parse_test.go b/go/vt/sqlparser/parse_test.go index b80ded73b0b..bb51bbb2479 100644 --- a/go/vt/sqlparser/parse_test.go +++ b/go/vt/sqlparser/parse_test.go @@ -1360,7 +1360,7 @@ var ( input: "delete /* limit */ from a limit b", }, { input: "delete /* alias where */ t.* from a as t where t.id = 2", - output: "delete /* alias where */ t from a as t where t.id = 2", + output: "delete /* alias where */ from a as t where t.id = 2", }, { input: "delete t.* from t, t1", output: "delete t from t, t1", diff --git a/go/vt/vtgate/planbuilder/delete.go b/go/vt/vtgate/planbuilder/delete.go index e8b71ea9a0e..059c663465d 100644 --- a/go/vt/vtgate/planbuilder/delete.go +++ b/go/vt/vtgate/planbuilder/delete.go @@ -144,16 +144,7 @@ func checkIfDeleteSupported(del *sqlparser.Delete, semTable *semantics.SemTable) return semTable.NotUnshardedErr } - // Delete is only supported for a single TableExpr which is supposed to be an aliased expression - multiShardErr := vterrors.VT12001("multi-shard or vindex write statement") - if len(del.TableExprs) != 1 { - return multiShardErr - } - _, isAliasedExpr := del.TableExprs[0].(*sqlparser.AliasedTableExpr) - if !isAliasedExpr { - return multiShardErr - } - + // Delete is only supported for single Target. if len(del.Targets) > 1 { return vterrors.VT12001("multi-table DELETE statement in a sharded keyspace") } diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index abbda050b0c..65012e68e02 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -527,7 +527,7 @@ func transformRoutePlan(ctx *plancontext.PlanningContext, op *operators.Route) ( case *sqlparser.Update: return buildUpdateLogicalPlan(ctx, op, dmlOp, stmt, hints) case *sqlparser.Delete: - return buildDeleteLogicalPlan(ctx, op, dmlOp, hints) + return buildDeleteLogicalPlan(ctx, op, dmlOp, stmt, hints) case *sqlparser.Insert: return buildInsertLogicalPlan(op, dmlOp, stmt, hints) default: @@ -689,24 +689,20 @@ func buildUpdateLogicalPlan( return &primitiveWrapper{prim: e}, nil } -func buildDeleteLogicalPlan( - ctx *plancontext.PlanningContext, - rb *operators.Route, - dmlOp operators.Operator, - hints *queryHints, -) (logicalPlan, error) { +func buildDeleteLogicalPlan(ctx *plancontext.PlanningContext, rb *operators.Route, dmlOp operators.Operator, stmt *sqlparser.Delete, hints *queryHints) (logicalPlan, error) { del := dmlOp.(*operators.Delete) rp := newRoutingParams(ctx, rb.Routing.OpCode()) rb.Routing.UpdateRoutingParams(ctx, rp) + vtable := del.Target.VTable edml := &engine.DML{ - Query: generateQuery(del.AST), - TableNames: []string{del.VTable.Name.String()}, - Vindexes: del.VTable.Owned, + Query: generateQuery(stmt), + TableNames: []string{vtable.Name.String()}, + Vindexes: vtable.Owned, OwnedVindexQuery: del.OwnedVindexQuery, RoutingParameters: rp, } - transformDMLPlan(del.VTable, edml, rb.Routing, del.OwnedVindexQuery != "") + transformDMLPlan(vtable, edml, rb.Routing, del.OwnedVindexQuery != "") e := &engine.Delete{ DML: edml, diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 961a7d252ff..1a9ef3c77c1 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -309,7 +309,7 @@ func (ts *tableSorter) Swap(i, j int) { func removeKeyspaceFromSelectExpr(expr sqlparser.SelectExpr) { switch expr := expr.(type) { case *sqlparser.AliasedExpr: - sqlparser.RemoveKeyspaceFromColName(expr.Expr) + sqlparser.RemoveKeyspace(expr.Expr) case *sqlparser.StarExpr: expr.TableName.Qualifier = sqlparser.NewIdentifierCS("") } @@ -376,7 +376,7 @@ func buildQuery(op Operator, qb *queryBuilder) { case *Update: buildUpdate(op, qb) case *Delete: - buildDML(op, qb) + buildDelete(op, qb) case *Insert: buildDML(op, qb) default: @@ -384,6 +384,28 @@ func buildQuery(op Operator, qb *queryBuilder) { } } +func buildDelete(op *Delete, qb *queryBuilder) { + buildQuery(op.Source, qb) + // currently the qb builds a select query underneath. + // Will take the `From` and `Where` from this select + // and create a delete statement. + // TODO: change it to directly produce `delete` statement. + sel, ok := qb.stmt.(*sqlparser.Select) + if !ok { + panic(vterrors.VT13001("expected a select here")) + } + + qb.dmlOperator = op + qb.stmt = &sqlparser.Delete{ + Ignore: sqlparser.Ignore(op.Ignore), + Targets: sqlparser.TableNames{op.Target.Name}, + TableExprs: sel.From, + Where: sel.Where, + OrderBy: op.OrderBy, + Limit: op.Limit, + } +} + func buildUpdate(op *Update, qb *queryBuilder) { tblName := sqlparser.NewTableName(op.QTable.Table.Name.String()) aTblExpr := &sqlparser.AliasedTableExpr{ diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 63dec0c84a8..7a4758493b2 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -72,7 +72,7 @@ func addWherePredicates(ctx *plancontext.PlanningContext, expr sqlparser.Expr, o outerID := TableID(op) exprs := sqlparser.SplitAndExpression(nil, expr) for _, expr := range exprs { - sqlparser.RemoveKeyspaceFromColName(expr) + sqlparser.RemoveKeyspace(expr) subq := sqc.handleSubquery(ctx, expr, outerID) if subq != nil { continue diff --git a/go/vt/vtgate/planbuilder/operators/delete.go b/go/vt/vtgate/planbuilder/operators/delete.go index 17f6125992f..ffa851fdecb 100644 --- a/go/vt/vtgate/planbuilder/operators/delete.go +++ b/go/vt/vtgate/planbuilder/operators/delete.go @@ -21,43 +21,54 @@ import ( "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vtgate/engine" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" "vitess.io/vitess/go/vt/vtgate/semantics" "vitess.io/vitess/go/vt/vtgate/vindexes" ) type Delete struct { - QTable *QueryTable - VTable *vindexes.Table + Target TargetTable OwnedVindexQuery string - AST *sqlparser.Delete + OrderBy sqlparser.OrderBy + Limit *sqlparser.Limit + Ignore bool + Source Operator - noInputs noColumns noPredicates } +type TargetTable struct { + ID semantics.TableSet + VTable *vindexes.Table + Name sqlparser.TableName +} + // Introduces implements the PhysicalOperator interface func (d *Delete) introducesTableID() semantics.TableSet { - return d.QTable.ID + return d.Target.ID } // Clone implements the Operator interface -func (d *Delete) Clone([]Operator) Operator { - return &Delete{ - QTable: d.QTable, - VTable: d.VTable, - OwnedVindexQuery: d.OwnedVindexQuery, - AST: d.AST, +func (d *Delete) Clone(inputs []Operator) Operator { + newD := *d + newD.SetInputs(inputs) + return &newD +} + +func (d *Delete) Inputs() []Operator { + return []Operator{d.Source} +} + +func (d *Delete) SetInputs(inputs []Operator) { + if len(inputs) != 1 { + panic(vterrors.VT13001("unexpected number of inputs to Delete operator")) } + d.Source = inputs[0] } func (d *Delete) TablesUsed() []string { - if d.VTable != nil { - return SingleQualifiedIdentifier(d.VTable.Keyspace, d.VTable.Name) - } - return nil + return SingleQualifiedIdentifier(d.Target.VTable.Keyspace, d.Target.VTable.Name) } func (d *Delete) GetOrdering(*plancontext.PlanningContext) []OrderBy { @@ -65,20 +76,23 @@ func (d *Delete) GetOrdering(*plancontext.PlanningContext) []OrderBy { } func (d *Delete) ShortDescription() string { - return fmt.Sprintf("%s.%s %s", d.VTable.Keyspace.Name, d.VTable.Name.String(), sqlparser.String(d.AST.Where)) -} + limit := "" + orderBy := "" + if d.Limit != nil { + limit = " " + sqlparser.String(d.Limit) + } + if len(d.OrderBy) > 0 { + orderBy = " " + sqlparser.String(d.OrderBy) + } -func (d *Delete) Statement() sqlparser.Statement { - return d.AST + return fmt.Sprintf("%s.%s%s%s", d.Target.VTable.Keyspace.Name, d.Target.VTable.Name.String(), orderBy, limit) } func createOperatorFromDelete(ctx *plancontext.PlanningContext, deleteStmt *sqlparser.Delete) Operator { - tableInfo, qt := createQueryTableForDML(ctx, deleteStmt.TableExprs[0], deleteStmt.Where) - vindexTable, routing := buildVindexTableForDML(ctx, tableInfo, qt, "delete") - delClone := sqlparser.CloneRefOfDelete(deleteStmt) - // Create the delete operator first. - delOp := createDeleteOperator(ctx, deleteStmt, qt, vindexTable, routing) + + delOp := createDeleteOperator(ctx, deleteStmt) + if deleteStmt.Comments != nil { delOp = &LockAndComment{ Source: delOp, @@ -92,64 +106,91 @@ func createOperatorFromDelete(ctx *plancontext.PlanningContext, deleteStmt *sqlp return delOp } // If the delete statement has a limit, we don't support it yet. - if deleteStmt.Limit != nil { + if delClone.Limit != nil { panic(vterrors.VT12001("foreign keys management at vitess with limit")) } return createFkCascadeOpForDelete(ctx, delOp, delClone, childFks) } -func createDeleteOperator( - ctx *plancontext.PlanningContext, - deleteStmt *sqlparser.Delete, - qt *QueryTable, - vindexTable *vindexes.Table, - routing Routing) Operator { - del := &Delete{ - QTable: qt, - VTable: vindexTable, - AST: deleteStmt, - } - route := &Route{ - Source: del, - Routing: routing, - } +func createDeleteOperator(ctx *plancontext.PlanningContext, del *sqlparser.Delete) Operator { + op := crossJoin(ctx, del.TableExprs) - if !vindexTable.Keyspace.Sharded { - return route + if del.Where != nil { + op = addWherePredicates(ctx, del.Where.Expr, op) } - primaryVindex, vindexAndPredicates := getVindexInformation(qt.ID, vindexTable) - - tr, ok := routing.(*ShardedRouting) - if ok { - tr.VindexPreds = vindexAndPredicates + target := del.Targets[0] + tblID, exists := ctx.SemTable.Targets[target.Name] + if !exists { + panic(vterrors.VT13001("delete target table should be part of semantic analyzer")) } - - var ovq string - if len(vindexTable.Owned) > 0 { - tblExpr := &sqlparser.AliasedTableExpr{Expr: sqlparser.TableName{Name: vindexTable.Name}, As: qt.Alias.As} - ovq = generateOwnedVindexQuery(tblExpr, deleteStmt, vindexTable, primaryVindex.Columns) + tblInfo, err := ctx.SemTable.TableInfoFor(tblID) + if err != nil { + panic(err) } - del.OwnedVindexQuery = ovq + vTbl := tblInfo.GetVindexTable() + // Reference table should delete from the source table. + if vTbl.Type == vindexes.TypeReference && vTbl.Source != nil { + vTbl = updateQueryGraphWithSource(ctx, op, tblID, vTbl) + } - sqc := &SubQueryBuilder{} - for _, predicate := range qt.Predicates { - subq := sqc.handleSubquery(ctx, predicate, qt.ID) - if subq != nil { - continue + var ovq string + if vTbl.Keyspace.Sharded && vTbl.Type == vindexes.TypeTable { + primaryVindex, _ := getVindexInformation(tblID, vTbl) + ate := tblInfo.GetAliasedTableExpr() + if len(vTbl.Owned) > 0 { + ovq = generateOwnedVindexQuery(ate, del, vTbl, primaryVindex.Columns) } + } - routing = UpdateRoutingLogic(ctx, predicate, routing) + name, err := tblInfo.Name() + if err != nil { + panic(err) } - if routing.OpCode() == engine.Scatter && deleteStmt.Limit != nil { - // TODO systay: we should probably check for other op code types - IN could also hit multiple shards (2022-04-07) - panic(vterrors.VT12001("multi shard DELETE with LIMIT")) + return &Delete{ + Target: TargetTable{ + ID: tblID, + VTable: vTbl, + Name: name, + }, + Source: op, + Ignore: bool(del.Ignore), + Limit: del.Limit, + OrderBy: del.OrderBy, + OwnedVindexQuery: ovq, } +} - return sqc.getRootOperator(route, nil) +func updateQueryGraphWithSource(ctx *plancontext.PlanningContext, input Operator, tblID semantics.TableSet, vTbl *vindexes.Table) *vindexes.Table { + sourceTable, _, _, _, _, err := ctx.VSchema.FindTableOrVindex(vTbl.Source.TableName) + if err != nil { + panic(err) + } + vTbl = sourceTable + TopDown(input, TableID, func(op Operator, lhsTables semantics.TableSet, isRoot bool) (Operator, *ApplyResult) { + qg, ok := op.(*QueryGraph) + if !ok { + return op, NoRewrite + } + if len(qg.Tables) > 1 { + panic(vterrors.VT12001("DELETE on reference table with join")) + } + for _, tbl := range qg.Tables { + if tbl.ID != tblID { + continue + } + tbl.Alias = sqlparser.NewAliasedTableExpr(sqlparser.NewTableName(vTbl.Name.String()), tbl.Alias.As.String()) + tbl.Table, _ = tbl.Alias.TableName() + } + return op, Rewrote("change query table point to source table") + }, func(operator Operator) VisitRule { + _, ok := operator.(*QueryGraph) + return VisitRule(ok) + }) + return vTbl } func createFkCascadeOpForDelete(ctx *plancontext.PlanningContext, parentOp Operator, delStmt *sqlparser.Delete, childFks []vindexes.ChildFKInfo) Operator { diff --git a/go/vt/vtgate/planbuilder/operators/join.go b/go/vt/vtgate/planbuilder/operators/join.go index 35bf26f9793..42ec1b75562 100644 --- a/go/vt/vtgate/planbuilder/operators/join.go +++ b/go/vt/vtgate/planbuilder/operators/join.go @@ -92,7 +92,7 @@ func createOuterJoin(tableExpr *sqlparser.JoinTableExpr, lhs, rhs Operator) Oper panic(vterrors.VT12001("subquery in outer join predicate")) } predicate := tableExpr.Condition.On - sqlparser.RemoveKeyspaceFromColName(predicate) + sqlparser.RemoveKeyspace(predicate) return &Join{LHS: lhs, RHS: rhs, LeftJoin: true, Predicate: predicate} } @@ -115,7 +115,7 @@ func createInnerJoin(ctx *plancontext.PlanningContext, tableExpr *sqlparser.Join sqc := &SubQueryBuilder{} outerID := TableID(op) joinPredicate := tableExpr.Condition.On - sqlparser.RemoveKeyspaceFromColName(joinPredicate) + sqlparser.RemoveKeyspace(joinPredicate) exprs := sqlparser.SplitAndExpression(nil, joinPredicate) for _, pred := range exprs { subq := sqc.handleSubquery(ctx, pred, outerID) diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index 19f6f3bf27d..b2d51c2935e 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -21,6 +21,7 @@ import ( "io" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" "vitess.io/vitess/go/vt/vtgate/semantics" ) @@ -96,6 +97,8 @@ func runRewriters(ctx *plancontext.PlanningContext, root Operator) Operator { return optimizeQueryGraph(ctx, in) case *LockAndComment: return pushLockAndComment(in) + case *Delete: + return tryPushDelete(in) default: return in, NoRewrite } @@ -104,6 +107,32 @@ func runRewriters(ctx *plancontext.PlanningContext, root Operator) Operator { return FixedPointBottomUp(root, TableID, visitor, stopAtRoute) } +func tryPushDelete(in *Delete) (Operator, *ApplyResult) { + switch src := in.Source.(type) { + case *Route: + if in.Limit != nil && !src.IsSingleShardOrByDestination() { + panic(vterrors.VT12001("multi shard DELETE with LIMIT")) + } + + switch r := src.Routing.(type) { + case *SequenceRouting: + // Sequences are just unsharded routes + src.Routing = &AnyShardRouting{ + keyspace: r.keyspace, + } + case *AnyShardRouting: + // References would have an unsharded source + // Alternates are not required. + r.Alternates = nil + } + return Swap(in, src, "pushed delete under route") + case *ApplyJoin: + panic(vterrors.VT12001("multi shard DELETE with join table references")) + } + + return in, nil +} + func pushLockAndComment(l *LockAndComment) (Operator, *ApplyResult) { switch src := l.Source.(type) { case *Horizon, *QueryGraph: diff --git a/go/vt/vtgate/planbuilder/operators/route.go b/go/vt/vtgate/planbuilder/operators/route.go index c540ad6791d..952e455abb0 100644 --- a/go/vt/vtgate/planbuilder/operators/route.go +++ b/go/vt/vtgate/planbuilder/operators/route.go @@ -19,8 +19,11 @@ package operators import ( "fmt" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/slice" + "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" @@ -278,6 +281,14 @@ func (r *Route) IsSingleShard() bool { return false } +func (r *Route) IsSingleShardOrByDestination() bool { + switch r.Routing.OpCode() { + case engine.Unsharded, engine.DBA, engine.Next, engine.EqualUnique, engine.Reference, engine.ByDestination: + return true + } + return false +} + func tupleAccess(expr sqlparser.Expr, coordinates []int) sqlparser.Expr { tuple, _ := expr.(sqlparser.ValTuple) for _, idx := range coordinates { @@ -375,23 +386,55 @@ func findVSchemaTableAndCreateRoute( solves semantics.TableSet, planAlternates bool, ) *Route { - vschemaTable, _, _, _, target, err := ctx.VSchema.FindTableOrVindex(tableName) - if target != nil { - panic(vterrors.VT09017("SELECT with a target destination is not allowed")) - } + vschemaTable, _, _, tabletType, target, err := ctx.VSchema.FindTableOrVindex(tableName) if err != nil { panic(err) } + targeted := createTargetedRouting(ctx, target, tabletType, vschemaTable) + return createRouteFromVSchemaTable( ctx, queryTable, vschemaTable, solves, planAlternates, + targeted, ) } +func createTargetedRouting(ctx *plancontext.PlanningContext, target key.Destination, tabletType topodatapb.TabletType, vschemaTable *vindexes.Table) Routing { + switch ctx.Statement.(type) { + case *sqlparser.Update: + if tabletType != topodatapb.TabletType_PRIMARY { + panic(vterrors.VT09002("update")) + } + case *sqlparser.Delete: + if tabletType != topodatapb.TabletType_PRIMARY { + panic(vterrors.VT09002("delete")) + } + case *sqlparser.Insert: + if tabletType != topodatapb.TabletType_PRIMARY { + panic(vterrors.VT09002("insert")) + } + if target != nil { + panic(vterrors.VT09017("INSERT with a target destination is not allowed")) + } + case sqlparser.SelectStatement: + if target != nil { + panic(vterrors.VT09017("SELECT with a target destination is not allowed")) + } + } + + if target != nil { + return &TargetedRouting{ + keyspace: vschemaTable.Keyspace, + TargetDestination: target, + } + } + return nil +} + // createRouteFromTable creates a route from the given VSchema table. func createRouteFromVSchemaTable( ctx *plancontext.PlanningContext, @@ -399,6 +442,7 @@ func createRouteFromVSchemaTable( vschemaTable *vindexes.Table, solves semantics.TableSet, planAlternates bool, + targeted Routing, ) *Route { if vschemaTable.Name.String() != queryTable.Table.Name.String() { // we are dealing with a routed table @@ -423,8 +467,14 @@ func createRouteFromVSchemaTable( }, } - // We create the appropiate Routing struct here, depending on the type of table we are dealing with. - routing := createRoutingForVTable(vschemaTable, solves) + // We create the appropriate Routing struct here, depending on the type of table we are dealing with. + var routing Routing + if targeted != nil { + routing = targeted + } else { + routing = createRoutingForVTable(vschemaTable, solves) + } + for _, predicate := range queryTable.Predicates { routing = UpdateRoutingLogic(ctx, predicate, routing) } diff --git a/go/vt/vtgate/planbuilder/operators/route_planning.go b/go/vt/vtgate/planbuilder/operators/route_planning.go index cb33f4e1f55..07dbab3bc90 100644 --- a/go/vt/vtgate/planbuilder/operators/route_planning.go +++ b/go/vt/vtgate/planbuilder/operators/route_planning.go @@ -138,6 +138,7 @@ func generateOwnedVindexQuery(tblExpr sqlparser.TableExpr, del *sqlparser.Delete buf.Myprintf(", %v", column) } } + sqlparser.RemoveKeyspaceInTables(tblExpr) buf.Myprintf(" from %v%v%v%v for update", tblExpr, del.Where, del.OrderBy, del.Limit) return buf.String() } diff --git a/go/vt/vtgate/planbuilder/operators/subquery_builder.go b/go/vt/vtgate/planbuilder/operators/subquery_builder.go index 6540ed10701..e582295ba91 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_builder.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_builder.go @@ -201,7 +201,7 @@ func (sqb *SubQueryBuilder) inspectWhere( outerID: sqb.outerID, } for _, predicate := range sqlparser.SplitAndExpression(nil, in.Expr) { - sqlparser.RemoveKeyspaceFromColName(predicate) + sqlparser.RemoveKeyspace(predicate) subq := sqb.handleSubquery(ctx, predicate, sqb.totalID) if subq != nil { continue diff --git a/go/vt/vtgate/planbuilder/operators/table.go b/go/vt/vtgate/planbuilder/operators/table.go index 93b406232b2..bf03243bb81 100644 --- a/go/vt/vtgate/planbuilder/operators/table.go +++ b/go/vt/vtgate/planbuilder/operators/table.go @@ -115,7 +115,7 @@ func addColumn(ctx *plancontext.PlanningContext, op ColNameColumns, e sqlparser. if !ok { panic(vterrors.VT09018(fmt.Sprintf("cannot add '%s' expression to a table/vindex", sqlparser.String(e)))) } - sqlparser.RemoveKeyspaceFromColName(col) + sqlparser.RemoveKeyspace(col) cols := op.GetColNames() colAsExpr := func(c *sqlparser.ColName) sqlparser.Expr { return c } if offset, found := canReuseColumn(ctx, cols, e, colAsExpr); found { diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index e53ce5d5885..3871c8fdbc4 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -49,6 +49,9 @@ type PlanningContext struct { // CurrentPhase keeps track of how far we've gone in the planning process // The type should be operators.Phase, but depending on that would lead to circular dependencies CurrentPhase int + + // Statement contains the originally parsed statement + Statement sqlparser.Statement } func CreatePlanningContext(stmt sqlparser.Statement, @@ -77,6 +80,7 @@ func CreatePlanningContext(stmt sqlparser.Statement, SkipPredicates: map[sqlparser.Expr]any{}, PlannerVersion: version, ReservedArguments: map[sqlparser.Expr]string{}, + Statement: stmt, }, nil } diff --git a/go/vt/vtgate/planbuilder/single_sharded_shortcut.go b/go/vt/vtgate/planbuilder/single_sharded_shortcut.go index daf19ced859..e3999c0703d 100644 --- a/go/vt/vtgate/planbuilder/single_sharded_shortcut.go +++ b/go/vt/vtgate/planbuilder/single_sharded_shortcut.go @@ -20,11 +20,10 @@ import ( "sort" "strings" - "vitess.io/vitess/go/vt/vtgate/planbuilder/operators" - "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" - "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/planbuilder/operators" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" "vitess.io/vitess/go/vt/vtgate/semantics" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -106,7 +105,7 @@ func getTableNames(semTable *semantics.SemTable) ([]sqlparser.TableName, error) func removeKeyspaceFromSelectExpr(expr sqlparser.SelectExpr) { switch expr := expr.(type) { case *sqlparser.AliasedExpr: - sqlparser.RemoveKeyspaceFromColName(expr.Expr) + sqlparser.RemoveKeyspace(expr.Expr) case *sqlparser.StarExpr: expr.TableName.Qualifier = sqlparser.NewIdentifierCS("") } diff --git a/go/vt/vtgate/planbuilder/testdata/dml_cases.json b/go/vt/vtgate/planbuilder/testdata/dml_cases.json index eb257064afd..5ca6b034d24 100644 --- a/go/vt/vtgate/planbuilder/testdata/dml_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/dml_cases.json @@ -4926,6 +4926,126 @@ "user.music" ] } + }, + { + "comment": "delete from reference table - query send to source table", + "query": "delete from user.ref_with_source where col = 1", + "plan": { + "QueryType": "DELETE", + "Original": "delete from user.ref_with_source where col = 1", + "Instructions": { + "OperatorType": "Delete", + "Variant": "Unsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "delete from source_of_ref where col = 1", + "Table": "source_of_ref" + }, + "TablesUsed": [ + "main.source_of_ref" + ] + } + }, + { + "comment": "delete from reference table - no source", + "query": "delete from user.ref", + "plan": { + "QueryType": "DELETE", + "Original": "delete from user.ref", + "Instructions": { + "OperatorType": "Delete", + "Variant": "Reference", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "Query": "delete from ref", + "Table": "ref" + }, + "TablesUsed": [ + "user.ref" + ] + } + }, + { + "comment": "delete by target destination with limit", + "query": "delete from `user[-]`.`user` limit 20", + "plan": { + "QueryType": "DELETE", + "Original": "delete from `user[-]`.`user` limit 20", + "Instructions": { + "OperatorType": "Delete", + "Variant": "ByDestination", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "KsidLength": 1, + "KsidVindex": "user_index", + "OwnedVindexQuery": "select Id, `Name`, Costly from `user` limit 20 for update", + "Query": "delete from `user` limit 20", + "Table": "user" + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "delete sharded table with join with reference table", + "query": "delete u from user u join ref_with_source r on u.col = r.col", + "plan": { + "QueryType": "DELETE", + "Original": "delete u from user u join ref_with_source r on u.col = r.col", + "Instructions": { + "OperatorType": "Delete", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "KsidLength": 1, + "KsidVindex": "user_index", + "OwnedVindexQuery": "select Id, `Name`, Costly from `user` as u for update", + "Query": "delete u from `user` as u, ref_with_source as r where u.col = r.col", + "Table": "user" + }, + "TablesUsed": [ + "user.ref_with_source", + "user.user" + ] + } + }, + { + "comment": "delete sharded table with join with another sharded table on vindex column", + "query": "delete u from user u join music m on u.id = m.user_id", + "plan": { + "QueryType": "DELETE", + "Original": "delete u from user u join music m on u.id = m.user_id", + "Instructions": { + "OperatorType": "Delete", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "KsidLength": 1, + "KsidVindex": "user_index", + "OwnedVindexQuery": "select Id, `Name`, Costly from `user` as u for update", + "Query": "delete u from `user` as u, music as m where u.id = m.user_id", + "Table": "user" + }, + "TablesUsed": [ + "user.music", + "user.user" + ] + } } - ] diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index 8c4b7c89e44..b1c1c45001c 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -54,11 +54,6 @@ "query": "update user_extra set val = 1 where (name = 'foo' or id = 1) limit 1", "plan": "VT12001: unsupported: multi shard UPDATE with LIMIT" }, - { - "comment": "multi delete multi table", - "query": "delete user from user join user_extra on user.id = user_extra.id where user.name = 'foo'", - "plan": "VT12001: unsupported: multi-shard or vindex write statement" - }, { "comment": "update changes primary vindex column", "query": "update user set id = 1 where id = 1", @@ -162,7 +157,7 @@ { "comment": "delete with multi-table targets", "query": "delete music,user from music inner join user where music.id = user.id", - "plan": "VT12001: unsupported: multi-shard or vindex write statement" + "plan": "VT12001: unsupported: multi-table DELETE statement in a sharded keyspace" }, { "comment": "select get_lock with non-dual table", @@ -388,5 +383,35 @@ "comment": "We need schema tracking to allow unexpanded columns inside UNION", "query": "select x from (select t.*, 0 as x from user t union select t.*, 1 as x from user_extra t) AS t", "plan": "VT09015: schema tracking required" + }, + { + "comment": "multi table delete with 2 sharded tables join on vindex column", + "query": "delete u, m from user u join music m on u.id = m.user_id", + "plan": "VT12001: unsupported: multi-table DELETE statement in a sharded keyspace" + }, + { + "comment": "multi table delete with 2 sharded tables join on non-vindex column", + "query": "delete u, m from user u join music m on u.col = m.col", + "plan": "VT12001: unsupported: multi-table DELETE statement in a sharded keyspace" + }, + { + "comment": "multi table delete with 1 sharded and 1 reference table", + "query": "delete u, r from user u join ref_with_source r on u.col = r.col", + "plan": "VT12001: unsupported: multi-table DELETE statement in a sharded keyspace" + }, + { + "comment": "multi delete multi table", + "query": "delete user from user join user_extra on user.id = user_extra.id where user.name = 'foo'", + "plan": "VT12001: unsupported: multi shard DELETE with join table references" + }, + { + "comment": "multi delete multi table with alias", + "query": "delete u from user u join music m on u.col = m.col", + "plan": "VT12001: unsupported: multi shard DELETE with join table references" + }, + { + "comment": "reference table delete with join", + "query": "delete r from user u join ref_with_source r on u.col = r.col", + "plan": "VT12001: unsupported: DELETE on reference table with join" } ] diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 4959045458f..17e9398f7f6 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -119,6 +119,7 @@ func (a *analyzer) newSemTable(statement sqlparser.Statement, coll collations.ID Direct: a.binder.direct, ExprTypes: a.typer.m, Tables: a.tables.Tables, + Targets: a.binder.targets, NotSingleRouteErr: a.projErr, NotUnshardedErr: a.unshardedErr, Warning: a.warning, diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index 8f0cc7a9704..e222cf619bd 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -275,6 +275,26 @@ func TestBindingMultiAliasedTableNegative(t *testing.T) { } } +func TestBindingDelete(t *testing.T) { + queries := []string{ + "delete tbl from tbl", + "delete from tbl", + "delete t1 from t1, t2", + } + for _, query := range queries { + t.Run(query, func(t *testing.T) { + stmt, semTable := parseAndAnalyze(t, query, "d") + del := stmt.(*sqlparser.Delete) + t1 := del.TableExprs[0].(*sqlparser.AliasedTableExpr) + ts := semTable.TableSetFor(t1) + assert.Equal(t, SingleTableSet(0), ts) + + actualTs := semTable.Targets[del.Targets[0].Name] + assert.Equal(t, ts, actualTs) + }) + } +} + func TestNotUniqueTableName(t *testing.T) { queries := []string{ "select * from t, t", diff --git a/go/vt/vtgate/semantics/binder.go b/go/vt/vtgate/semantics/binder.go index 27d059673cb..33422c3aa37 100644 --- a/go/vt/vtgate/semantics/binder.go +++ b/go/vt/vtgate/semantics/binder.go @@ -20,6 +20,8 @@ import ( "strings" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/evalengine" ) // binder is responsible for finding all the column references in @@ -29,6 +31,7 @@ import ( type binder struct { recursive ExprDependencies direct ExprDependencies + targets map[sqlparser.IdentifierCS]TableSet scoper *scoper tc *tableCollector org originable @@ -44,6 +47,7 @@ func newBinder(scoper *scoper, org originable, tc *tableCollector, typer *typer) return &binder{ recursive: map[sqlparser.Expr]TableSet{}, direct: map[sqlparser.Expr]TableSet{}, + targets: map[sqlparser.IdentifierCS]TableSet{}, scoper: scoper, org: org, tc: tc, @@ -106,10 +110,47 @@ func (b *binder) up(cursor *sqlparser.Cursor) error { b.typer.m[ae.Expr] = t } } + case sqlparser.TableNames: + _, isDelete := cursor.Parent().(*sqlparser.Delete) + if !isDelete { + return nil + } + current := b.scoper.currentScope() + for _, target := range node { + finalDep, err := b.findDependentTableSet(current, target) + if err != nil { + return err + } + b.targets[target.Name] = finalDep.direct + } } return nil } +func (b *binder) findDependentTableSet(current *scope, target sqlparser.TableName) (dependency, error) { + var deps dependencies = ¬hing{} + for _, table := range current.tables { + tblName, err := table.Name() + if err != nil { + continue + } + if tblName.Name.String() != target.Name.String() { + continue + } + ts := b.org.tableSetFor(table.GetAliasedTableExpr()) + c := createCertain(ts, ts, evalengine.Type{}) + deps = deps.merge(c, false) + } + finalDep, err := deps.get() + if err != nil { + return dependency{}, err + } + if finalDep.direct != finalDep.recursive { + return dependency{}, vterrors.VT03004(target.Name.String()) + } + return finalDep, nil +} + func (b *binder) bindCountStar(node *sqlparser.CountStar) { scope := b.scoper.currentScope() var ts TableSet diff --git a/go/vt/vtgate/semantics/derived_table.go b/go/vt/vtgate/semantics/derived_table.go index fd649436ab0..0425d78ed93 100644 --- a/go/vt/vtgate/semantics/derived_table.go +++ b/go/vt/vtgate/semantics/derived_table.go @@ -141,7 +141,7 @@ func (dt *DerivedTable) Name() (sqlparser.TableName, error) { return dt.ASTNode.TableName() } -func (dt *DerivedTable) getAliasedTableExpr() *sqlparser.AliasedTableExpr { +func (dt *DerivedTable) GetAliasedTableExpr() *sqlparser.AliasedTableExpr { return dt.ASTNode } diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index c71941afdd5..3c1235dd376 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -64,6 +64,21 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { return r.handleWith(node) case *sqlparser.AliasedTableExpr: return r.handleAliasedTable(node) + case *sqlparser.Delete: + // When we do not have any target, it is a single table delete. + // In a single table delete, the table references is always a single aliased table expression. + if len(node.Targets) != 0 { + return nil + } + tblExpr, ok := node.TableExprs[0].(*sqlparser.AliasedTableExpr) + if !ok { + return nil + } + tblName, err := tblExpr.TableName() + if err != nil { + return err + } + node.Targets = append(node.Targets, tblName) } return nil } diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index 09ddb223eef..476f993f3d7 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -573,3 +573,33 @@ func TestCTEToDerivedTableRewrite(t *testing.T) { }) } } + +// TestDeleteTargetTableRewrite checks that delete target rewrite is done correctly. +func TestDeleteTargetTableRewrite(t *testing.T) { + cDB := "db" + tcases := []struct { + sql string + target string + }{{ + sql: "delete from t", + target: "t", + }, { + sql: "delete from t t1", + target: "t1", + }, { + sql: "delete t2 from t t1, t t2", + target: "t2", + }, { + sql: "delete t2,t1 from t t1, t t2", + target: "t2, t1", + }} + for _, tcase := range tcases { + t.Run(tcase.sql, func(t *testing.T) { + ast, err := sqlparser.NewTestParser().Parse(tcase.sql) + require.NoError(t, err) + _, err = Analyze(ast, cDB, fakeSchemaInfo()) + require.NoError(t, err) + require.Equal(t, tcase.target, sqlparser.String(ast.(*sqlparser.Delete).Targets)) + }) + } +} diff --git a/go/vt/vtgate/semantics/real_table.go b/go/vt/vtgate/semantics/real_table.go index 7aafb697698..72549b98e8c 100644 --- a/go/vt/vtgate/semantics/real_table.go +++ b/go/vt/vtgate/semantics/real_table.go @@ -73,7 +73,7 @@ func (r *RealTable) getColumns() []ColumnInfo { } // GetExpr implements the TableInfo interface -func (r *RealTable) getAliasedTableExpr() *sqlparser.AliasedTableExpr { +func (r *RealTable) GetAliasedTableExpr() *sqlparser.AliasedTableExpr { return r.ASTNode } diff --git a/go/vt/vtgate/semantics/scoper.go b/go/vt/vtgate/semantics/scoper.go index c782da03678..878ac222911 100644 --- a/go/vt/vtgate/semantics/scoper.go +++ b/go/vt/vtgate/semantics/scoper.go @@ -197,7 +197,7 @@ func (s *scoper) up(cursor *sqlparser.Cursor) error { if isParentSelectStatement(cursor) { s.popScope() } - case *sqlparser.Select, sqlparser.GroupBy, *sqlparser.Update, *sqlparser.Delete, *sqlparser.Insert, *sqlparser.Union: + case *sqlparser.Select, sqlparser.GroupBy, *sqlparser.Update, *sqlparser.Insert, *sqlparser.Union: id := EmptyTableSet() for _, tableInfo := range s.currentScope().tables { set := tableInfo.getTableSet(s.org) diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index cf1ff7c2faf..7674a627b4e 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -50,7 +50,7 @@ type ( authoritative() bool // getAliasedTableExpr returns the AST struct behind this table - getAliasedTableExpr() *sqlparser.AliasedTableExpr + GetAliasedTableExpr() *sqlparser.AliasedTableExpr // canShortCut will return nil when the keyspace needs to be checked, // and a true/false if the decision has been made already @@ -117,6 +117,8 @@ type ( // It doesn't recurse inside derived tables to find the original dependencies. Direct ExprDependencies + Targets map[sqlparser.IdentifierCS]TableSet + // ColumnEqualities is used for transitive closures (e.g., if a == b and b == c, then a == c). ColumnEqualities map[columnName][]sqlparser.Expr @@ -517,7 +519,7 @@ func EmptySemTable() *SemTable { // TableSetFor returns the bitmask for this particular table func (st *SemTable) TableSetFor(t *sqlparser.AliasedTableExpr) TableSet { for idx, t2 := range st.Tables { - if t == t2.getAliasedTableExpr() { + if t == t2.GetAliasedTableExpr() { return SingleTableSet(idx) } } diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index 3940a19d107..bcf0402433a 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -207,7 +207,7 @@ func newVindexTable(t sqlparser.IdentifierCS) *vindexes.Table { // The code lives in this file since it is only touching tableCollector data func (tc *tableCollector) tableSetFor(t *sqlparser.AliasedTableExpr) TableSet { for i, t2 := range tc.Tables { - if t == t2.getAliasedTableExpr() { + if t == t2.GetAliasedTableExpr() { return SingleTableSet(i) } } diff --git a/go/vt/vtgate/semantics/vindex_table.go b/go/vt/vtgate/semantics/vindex_table.go index f78e68cbd5b..fba8f8ab9a0 100644 --- a/go/vt/vtgate/semantics/vindex_table.go +++ b/go/vt/vtgate/semantics/vindex_table.go @@ -67,8 +67,8 @@ func (v *VindexTable) Name() (sqlparser.TableName, error) { } // GetExpr implements the TableInfo interface -func (v *VindexTable) getAliasedTableExpr() *sqlparser.AliasedTableExpr { - return v.Table.getAliasedTableExpr() +func (v *VindexTable) GetAliasedTableExpr() *sqlparser.AliasedTableExpr { + return v.Table.GetAliasedTableExpr() } func (v *VindexTable) canShortCut() shortCut { diff --git a/go/vt/vtgate/semantics/vtable.go b/go/vt/vtgate/semantics/vtable.go index 271da126cd4..133e38ff505 100644 --- a/go/vt/vtgate/semantics/vtable.go +++ b/go/vt/vtgate/semantics/vtable.go @@ -70,7 +70,7 @@ func (v *vTableInfo) Name() (sqlparser.TableName, error) { return sqlparser.TableName{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "oh noes") } -func (v *vTableInfo) getAliasedTableExpr() *sqlparser.AliasedTableExpr { +func (v *vTableInfo) GetAliasedTableExpr() *sqlparser.AliasedTableExpr { return nil } diff --git a/go/vt/vtgate/vindexes/vschema.go b/go/vt/vtgate/vindexes/vschema.go index ba703f31c22..c20f5561566 100644 --- a/go/vt/vtgate/vindexes/vschema.go +++ b/go/vt/vtgate/vindexes/vschema.go @@ -55,6 +55,7 @@ var TabletTypeSuffix = map[topodatapb.TabletType]string{ // The following constants represent table types. const ( + TypeTable = "" TypeSequence = "sequence" TypeReference = "reference" )