diff --git a/go/vt/vtgate/planbuilder/operators/cte_merging.go b/go/vt/vtgate/planbuilder/operators/cte_merging.go index cb19e06b2a7..0c1556c81e4 100644 --- a/go/vt/vtgate/planbuilder/operators/cte_merging.go +++ b/go/vt/vtgate/planbuilder/operators/cte_merging.go @@ -31,7 +31,7 @@ func tryMergeRecurse(ctx *plancontext.PlanningContext, in *RecurseCTE) (Operator } func tryMergeCTE(ctx *plancontext.PlanningContext, seed, term Operator, in *RecurseCTE) *Route { - seedRoute, termRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(seed, term) + seedRoute, termRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(ctx, seed, term) if seedRoute == nil { return nil } diff --git a/go/vt/vtgate/planbuilder/operators/delete.go b/go/vt/vtgate/planbuilder/operators/delete.go index 81e36d54315..9cc7ffc2381 100644 --- a/go/vt/vtgate/planbuilder/operators/delete.go +++ b/go/vt/vtgate/planbuilder/operators/delete.go @@ -124,7 +124,7 @@ func createDeleteWithInputOp(ctx *plancontext.PlanningContext, del *sqlparser.De } var delOps []dmlOp - for _, target := range ctx.SemTable.Targets.Constituents() { + for _, target := range ctx.SemTable.DMLTargets.Constituents() { op := createDeleteOpWithTarget(ctx, target, del.Ignore) delOps = append(delOps, op) } diff --git a/go/vt/vtgate/planbuilder/operators/join_merging.go b/go/vt/vtgate/planbuilder/operators/join_merging.go index c035b7d11ed..707f41d6f51 100644 --- a/go/vt/vtgate/planbuilder/operators/join_merging.go +++ b/go/vt/vtgate/planbuilder/operators/join_merging.go @@ -28,7 +28,7 @@ import ( // If they can be merged, a new operator with the merged routing is returned // If they cannot be merged, nil is returned. func (jm *joinMerger) mergeJoinInputs(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinPredicates []sqlparser.Expr) *Route { - lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(lhs, rhs) + lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(ctx, lhs, rhs) if lhsRoute == nil { return nil } @@ -102,13 +102,13 @@ func mergeAnyShardRoutings(ctx *plancontext.PlanningContext, a, b *AnyShardRouti } } -func prepareInputRoutes(lhs Operator, rhs Operator) (*Route, *Route, Routing, Routing, routingType, routingType, bool) { +func prepareInputRoutes(ctx *plancontext.PlanningContext, lhs Operator, rhs Operator) (*Route, *Route, Routing, Routing, routingType, routingType, bool) { lhsRoute, rhsRoute := operatorsToRoutes(lhs, rhs) if lhsRoute == nil || rhsRoute == nil { return nil, nil, nil, nil, 0, 0, false } - lhsRoute, rhsRoute, routingA, routingB, sameKeyspace := getRoutesOrAlternates(lhsRoute, rhsRoute) + lhsRoute, rhsRoute, routingA, routingB, sameKeyspace := getRoutesOrAlternates(ctx, lhsRoute, rhsRoute) a, b := getRoutingType(routingA), getRoutingType(routingB) return lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace @@ -159,7 +159,7 @@ func (rt routingType) String() string { // getRoutesOrAlternates gets the Routings from each Route. If they are from different keyspaces, // we check if this is a table with alternates in other keyspaces that we can use -func getRoutesOrAlternates(lhsRoute, rhsRoute *Route) (*Route, *Route, Routing, Routing, bool) { +func getRoutesOrAlternates(ctx *plancontext.PlanningContext, lhsRoute, rhsRoute *Route) (*Route, *Route, Routing, Routing, bool) { routingA := lhsRoute.Routing routingB := rhsRoute.Routing sameKeyspace := routingA.Keyspace() == routingB.Keyspace() @@ -171,13 +171,15 @@ func getRoutesOrAlternates(lhsRoute, rhsRoute *Route) (*Route, *Route, Routing, return lhsRoute, rhsRoute, routingA, routingB, sameKeyspace } - if refA, ok := routingA.(*AnyShardRouting); ok { + if refA, ok := routingA.(*AnyShardRouting); ok && + !TableID(lhsRoute).IsOverlapping(ctx.SemTable.DMLTargets) { if altARoute := refA.AlternateInKeyspace(routingB.Keyspace()); altARoute != nil { return altARoute, rhsRoute, altARoute.Routing, routingB, true } } - if refB, ok := routingB.(*AnyShardRouting); ok { + if refB, ok := routingB.(*AnyShardRouting); ok && + !TableID(rhsRoute).IsOverlapping(ctx.SemTable.DMLTargets) { if altBRoute := refB.AlternateInKeyspace(routingA.Keyspace()); altBRoute != nil { return lhsRoute, altBRoute, routingA, altBRoute.Routing, true } diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index a2aca74fb6e..e222ae0f343 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -730,7 +730,7 @@ func mergeSubqueryInputs(ctx *plancontext.PlanningContext, in, out Operator, joi return nil } - inRoute, outRoute, inRouting, outRouting, sameKeyspace := getRoutesOrAlternates(inRoute, outRoute) + inRoute, outRoute, inRouting, outRouting, sameKeyspace := getRoutesOrAlternates(ctx, inRoute, outRoute) inner, outer := getRoutingType(inRouting), getRoutingType(outRouting) switch { diff --git a/go/vt/vtgate/planbuilder/operators/union_merging.go b/go/vt/vtgate/planbuilder/operators/union_merging.go index 000d176b61a..6173b59e0dc 100644 --- a/go/vt/vtgate/planbuilder/operators/union_merging.go +++ b/go/vt/vtgate/planbuilder/operators/union_merging.go @@ -108,7 +108,7 @@ func mergeUnionInputs( lhsExprs, rhsExprs sqlparser.SelectExprs, distinct bool, ) (Operator, sqlparser.SelectExprs) { - lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(lhs, rhs) + lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(ctx, lhs, rhs) if lhsRoute == nil { return nil, nil } diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index dd0a86c2de2..18a81175f7b 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -164,7 +164,7 @@ func createUpdateWithInputOp(ctx *plancontext.PlanningContext, upd *sqlparser.Up ueMap := prepareUpdateExpressionList(ctx, upd) var updOps []dmlOp - for _, target := range ctx.SemTable.Targets.Constituents() { + for _, target := range ctx.SemTable.DMLTargets.Constituents() { op := createUpdateOpWithTarget(ctx, upd, target, ueMap[target]) updOps = append(updOps, op) } @@ -308,7 +308,7 @@ func errIfUpdateNotSupported(ctx *plancontext.PlanningContext, stmt *sqlparser.U } } - // Now we check if any of the foreign key columns that are being udpated have dependencies on other updated columns. + // Now we check if any of the foreign key columns that are being updated have dependencies on other updated columns. // This is unsafe, and we currently don't support this in Vitess. if err := ctx.SemTable.ErrIfFkDependentColumnUpdated(stmt.Exprs); err != nil { panic(err) diff --git a/go/vt/vtgate/planbuilder/plan_test.go b/go/vt/vtgate/planbuilder/plan_test.go index 7135f4dff29..f3bed93e3c8 100644 --- a/go/vt/vtgate/planbuilder/plan_test.go +++ b/go/vt/vtgate/planbuilder/plan_test.go @@ -84,6 +84,7 @@ func (s *planTestSuite) TestPlan() { s.addPKsProvided(vschema, "user", []string{"user_extra"}, []string{"id", "user_id"}) s.addPKsProvided(vschema, "ordering", []string{"order"}, []string{"oid", "region_id"}) s.addPKsProvided(vschema, "ordering", []string{"order_event"}, []string{"oid", "ename"}) + s.addPKsProvided(vschema, "main", []string{"source_of_ref"}, []string{"id"}) // You will notice that some tests expect user.Id instead of user.id. // This is because we now pre-create vindex columns in the symbol @@ -305,6 +306,7 @@ func (s *planTestSuite) TestOne() { s.addPKsProvided(vschema, "user", []string{"user_extra"}, []string{"id", "user_id"}) s.addPKsProvided(vschema, "ordering", []string{"order"}, []string{"oid", "region_id"}) s.addPKsProvided(vschema, "ordering", []string{"order_event"}, []string{"oid", "ename"}) + s.addPKsProvided(vschema, "main", []string{"source_of_ref"}, []string{"id"}) s.testFile("onecase.json", vw, false) } diff --git a/go/vt/vtgate/planbuilder/testdata/reference_cases.json b/go/vt/vtgate/planbuilder/testdata/reference_cases.json index 6aa01355934..a379af52788 100644 --- a/go/vt/vtgate/planbuilder/testdata/reference_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/reference_cases.json @@ -771,5 +771,75 @@ "user.user_extra" ] } + }, + { + "comment": "update reference table with join with sharded table", + "query": "update main.source_of_ref as sr join main.rerouted_ref as rr on sr.id = rr.id inner join user.music as m on sr.col = m.col set sr.tt = 5 where m.user_id = 1", + "plan": { + "QueryType": "UPDATE", + "Original": "update main.source_of_ref as sr join main.rerouted_ref as rr on sr.id = rr.id inner join user.music as m on sr.col = m.col set sr.tt = 5 where m.user_id = 1", + "Instructions": { + "OperatorType": "DMLWithInput", + "TargetTabletType": "PRIMARY", + "Offset": [ + "0:[0]" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "R:0", + "JoinVars": { + "m_col": 0 + }, + "TableName": "music_rerouted_ref, source_of_ref", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select m.col from music as m where 1 != 1", + "Query": "select m.col from music as m where m.user_id = 1 lock in share mode", + "Table": "music", + "Values": [ + "1" + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select sr.id from source_of_ref as sr, rerouted_ref as rr where 1 != 1", + "Query": "select sr.id from source_of_ref as sr, rerouted_ref as rr where sr.col = :m_col and sr.id = rr.id lock in share mode", + "Table": "rerouted_ref, source_of_ref" + } + ] + }, + { + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update source_of_ref as sr set sr.tt = 5 where sr.id in ::dml_vals", + "Table": "source_of_ref" + } + ] + }, + "TablesUsed": [ + "main.rerouted_ref", + "main.source_of_ref", + "user.music" + ] + } } ] diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 0a9d2480d9b..25ab19a9947 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -174,7 +174,7 @@ func (a *analyzer) newSemTable( Direct: a.binder.direct, ExprTypes: a.typer.m, Tables: a.tables.Tables, - Targets: a.binder.targets, + DMLTargets: a.binder.targets, NotSingleRouteErr: a.notSingleRouteErr, NotUnshardedErr: a.unshardedErr, Warning: a.warning, diff --git a/go/vt/vtgate/semantics/semantic_table.go b/go/vt/vtgate/semantics/semantic_table.go index f9856a901a6..9e2a3703669 100644 --- a/go/vt/vtgate/semantics/semantic_table.go +++ b/go/vt/vtgate/semantics/semantic_table.go @@ -129,8 +129,8 @@ type ( // It doesn't recurse inside derived tables to find the original dependencies. Direct ExprDependencies - // Targets contains the TableSet of each table getting modified by the update/delete statement. - Targets TableSet + // DMLTargets contains the TableSet of each table getting modified by the update/delete statement. + DMLTargets TableSet // ColumnEqualities is used for transitive closures (e.g., if a == b and b == c, then a == c). ColumnEqualities map[columnName][]sqlparser.Expr @@ -202,7 +202,7 @@ func (st *SemTable) CopyDependencies(from, to sqlparser.Expr) { // GetChildForeignKeysForTargets gets the child foreign keys as a list for all the target tables. func (st *SemTable) GetChildForeignKeysForTargets() (fks []vindexes.ChildFKInfo) { - for _, ts := range st.Targets.Constituents() { + for _, ts := range st.DMLTargets.Constituents() { fks = append(fks, st.childForeignKeysInvolved[ts]...) } return fks @@ -210,7 +210,7 @@ func (st *SemTable) GetChildForeignKeysForTargets() (fks []vindexes.ChildFKInfo) // GetChildForeignKeysForTableSet gets the child foreign keys as a listfor the TableSet. func (st *SemTable) GetChildForeignKeysForTableSet(target TableSet) (fks []vindexes.ChildFKInfo) { - for _, ts := range st.Targets.Constituents() { + for _, ts := range st.DMLTargets.Constituents() { if target.IsSolvedBy(ts) { fks = append(fks, st.childForeignKeysInvolved[ts]...) } @@ -238,7 +238,7 @@ func (st *SemTable) GetChildForeignKeysList() []vindexes.ChildFKInfo { // GetParentForeignKeysForTargets gets the parent foreign keys as a list for all the target tables. func (st *SemTable) GetParentForeignKeysForTargets() (fks []vindexes.ParentFKInfo) { - for _, ts := range st.Targets.Constituents() { + for _, ts := range st.DMLTargets.Constituents() { fks = append(fks, st.parentForeignKeysInvolved[ts]...) } return fks @@ -246,7 +246,7 @@ func (st *SemTable) GetParentForeignKeysForTargets() (fks []vindexes.ParentFKInf // GetParentForeignKeysForTableSet gets the parent foreign keys as a list for the TableSet. func (st *SemTable) GetParentForeignKeysForTableSet(target TableSet) (fks []vindexes.ParentFKInfo) { - for _, ts := range st.Targets.Constituents() { + for _, ts := range st.DMLTargets.Constituents() { if target.IsSolvedBy(ts) { fks = append(fks, st.parentForeignKeysInvolved[ts]...) } @@ -970,7 +970,7 @@ func (st *SemTable) UpdateChildFKExpr(origUpdExpr *sqlparser.UpdateExpr, newExpr // GetTargetTableSetForTableName returns the TableSet for the given table name from the target tables. func (st *SemTable) GetTargetTableSetForTableName(name sqlparser.TableName) (TableSet, error) { - for _, target := range st.Targets.Constituents() { + for _, target := range st.DMLTargets.Constituents() { tbl, err := st.Tables[target.TableOffset()].Name() if err != nil { return "", err