diff --git a/doc/design-docs/AtomicTransactionsWithDisruptions.md b/doc/design-docs/AtomicTransactionsWithDisruptions.md index 706308b6b2b..4ca60778436 100644 --- a/doc/design-docs/AtomicTransactionsWithDisruptions.md +++ b/doc/design-docs/AtomicTransactionsWithDisruptions.md @@ -35,3 +35,17 @@ Therefore, the safest option was to always check if we need to redo the prepared When Vttabet restarts, all the previous connections are dropped. It starts in a non-serving state, and then after reading the shard and tablet records from the topo, it transitions to a serving state. As part of this transition we need to ensure that we redo the prepared transactions before we start accepting any writes. This is done as part of the `TxEngine.transition` function when we transition to an `AcceptingReadWrite` state. We call the same code for redoing the prepared transactions that we called for MySQL restarts, PRS and ERS. + +## Online DDL + +During an Online DDL cutover, we need to ensure that all the prepared transactions on the online DDL table needs to be completed before we can proceed with the cutover. +This is because the cutover involves a schema change and we cannot have any prepared transactions that are dependent on the old schema. + +As part of the cut-over process, Online DDL adds query rules to buffer new queries on the table. +It then checks for any open prepared transaction on the table and waits for up to 100ms if found, then checks again. +If it finds no prepared transaction of the table, it moves forward with the cut-over, otherwise it fails. The Online DDL mechanism will later retry the cut-over. + +In the Prepare code, we check the query rules before adding the transaction to the prepared list and re-check the rules before storing the transaction logs in the transaction redo table. +Any transaction that went past the first check will fail the second check if the cutover proceeds. + +The check on both sides prevents either the cutover from proceeding or the transaction from being prepared. \ No newline at end of file diff --git a/go/test/endtoend/transaction/twopc/stress/fuzzer_test.go b/go/test/endtoend/transaction/twopc/stress/fuzzer_test.go index e81d0d0d9ab..932fcae1217 100644 --- a/go/test/endtoend/transaction/twopc/stress/fuzzer_test.go +++ b/go/test/endtoend/transaction/twopc/stress/fuzzer_test.go @@ -34,7 +34,9 @@ import ( "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/syscallutil" + "vitess.io/vitess/go/test/endtoend/cluster" "vitess.io/vitess/go/vt/log" + "vitess.io/vitess/go/vt/schema" ) var ( @@ -78,7 +80,7 @@ func TestTwoPCFuzzTest(t *testing.T) { threads int updateSets int timeForTesting time.Duration - clusterDisruptions []func() + clusterDisruptions []func(t *testing.T) disruptionProbability []int }{ { @@ -100,12 +102,12 @@ func TestTwoPCFuzzTest(t *testing.T) { timeForTesting: 5 * time.Second, }, { - name: "Multiple Threads - Multiple Set - PRS, ERS, and MySQL and Vttablet restart disruptions", + name: "Multiple Threads - Multiple Set - PRS, ERS, and MySQL & Vttablet restart, OnlineDDL disruptions", threads: 15, updateSets: 15, timeForTesting: 5 * time.Second, - clusterDisruptions: []func(){prs, ers, mysqlRestarts, vttabletRestarts}, - disruptionProbability: []int{5, 5, 5, 5}, + clusterDisruptions: []func(t *testing.T){prs, ers, mysqlRestarts, vttabletRestarts, onlineDDLFuzzer}, + disruptionProbability: []int{5, 5, 5, 5, 5}, }, } @@ -113,7 +115,7 @@ func TestTwoPCFuzzTest(t *testing.T) { t.Run(tt.name, func(t *testing.T) { conn, closer := start(t) defer closer() - fz := newFuzzer(tt.threads, tt.updateSets, tt.clusterDisruptions, tt.disruptionProbability) + fz := newFuzzer(t, tt.threads, tt.updateSets, tt.clusterDisruptions, tt.disruptionProbability) fz.initialize(t, conn) conn.Close() @@ -190,6 +192,7 @@ func getThreadIDsForUpdateSetFromFuzzInsert(t *testing.T, conn *mysql.Conn, upda type fuzzer struct { threads int updateSets int + t *testing.T // shouldStop is an internal state variable, that tells the fuzzer // whether it should stop or not. @@ -199,14 +202,15 @@ type fuzzer struct { // updateRowVals are the rows that we use to ensure 1 update on each shard with the same increment. updateRowsVals [][]int // clusterDisruptions are the cluster level disruptions that can happen in a running cluster. - clusterDisruptions []func() + clusterDisruptions []func(t *testing.T) // disruptionProbability is the chance for the disruption to happen. We check this every 100 milliseconds. disruptionProbability []int } // newFuzzer creates a new fuzzer struct. -func newFuzzer(threads int, updateSets int, clusterDisruptions []func(), disruptionProbability []int) *fuzzer { +func newFuzzer(t *testing.T, threads int, updateSets int, clusterDisruptions []func(t *testing.T), disruptionProbability []int) *fuzzer { fz := &fuzzer{ + t: t, threads: threads, updateSets: updateSets, wg: sync.WaitGroup{}, @@ -364,7 +368,7 @@ func (fz *fuzzer) runClusterDisruptionThread(t *testing.T) { func (fz *fuzzer) runClusterDisruption(t *testing.T) { for idx, prob := range fz.disruptionProbability { if rand.Intn(100) < prob { - fz.clusterDisruptions[idx]() + fz.clusterDisruptions[idx](fz.t) return } } @@ -374,7 +378,7 @@ func (fz *fuzzer) runClusterDisruption(t *testing.T) { Cluster Level Disruptions for the fuzzer */ -func prs() { +func prs(t *testing.T) { shards := clusterInstance.Keyspaces[0].Shards shard := shards[rand.Intn(len(shards))] vttablets := shard.Vttablets @@ -386,7 +390,7 @@ func prs() { } } -func ers() { +func ers(t *testing.T) { shards := clusterInstance.Keyspaces[0].Shards shard := shards[rand.Intn(len(shards))] vttablets := shard.Vttablets @@ -398,7 +402,7 @@ func ers() { } } -func vttabletRestarts() { +func vttabletRestarts(t *testing.T) { shards := clusterInstance.Keyspaces[0].Shards shard := shards[rand.Intn(len(shards))] vttablets := shard.Vttablets @@ -422,7 +426,27 @@ func vttabletRestarts() { } } -func mysqlRestarts() { +var orderedDDLFuzzer = []string{ + "alter table twopc_fuzzer_insert add column extra_col1 varchar(20)", + "alter table twopc_fuzzer_insert add column extra_col2 varchar(20)", + "alter table twopc_fuzzer_insert drop column extra_col1", + "alter table twopc_fuzzer_insert drop column extra_col2", +} + +// onlineDDLFuzzer runs an online DDL statement while ignoring any errors for the fuzzer. +func onlineDDLFuzzer(t *testing.T) { + output, err := clusterInstance.VtctldClientProcess.ApplySchemaWithOutput(keyspaceName, orderedDDLFuzzer[count%len(orderedDDLFuzzer)], cluster.ApplySchemaParams{ + DDLStrategy: "vitess --force-cut-over-after=1ms", + }) + count++ + if err != nil { + return + } + fmt.Println("Running online DDL with uuid: ", output) + WaitForMigrationStatus(t, &vtParams, clusterInstance.Keyspaces[0].Shards, strings.TrimSpace(output), 2*time.Minute, schema.OnlineDDLStatusComplete, schema.OnlineDDLStatusFailed) +} + +func mysqlRestarts(t *testing.T) { shards := clusterInstance.Keyspaces[0].Shards shard := shards[rand.Intn(len(shards))] vttablets := shard.Vttablets diff --git a/go/test/endtoend/transaction/twopc/stress/main_test.go b/go/test/endtoend/transaction/twopc/stress/main_test.go index 9c7ed28fa1a..ef8e454be15 100644 --- a/go/test/endtoend/transaction/twopc/stress/main_test.go +++ b/go/test/endtoend/transaction/twopc/stress/main_test.go @@ -71,6 +71,7 @@ func TestMain(m *testing.M) { clusterInstance.VtTabletExtraArgs = append(clusterInstance.VtTabletExtraArgs, "--twopc_enable", "--twopc_abandon_age", "1", + "--migration_check_interval", "2s", ) // Start keyspace diff --git a/go/test/endtoend/transaction/twopc/stress/stress_test.go b/go/test/endtoend/transaction/twopc/stress/stress_test.go index 9912bdf6e19..4dae0156b9d 100644 --- a/go/test/endtoend/transaction/twopc/stress/stress_test.go +++ b/go/test/endtoend/transaction/twopc/stress/stress_test.go @@ -34,9 +34,12 @@ import ( "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/syscallutil" + "vitess.io/vitess/go/test/endtoend/cluster" + "vitess.io/vitess/go/test/endtoend/onlineddl" twopcutil "vitess.io/vitess/go/test/endtoend/transaction/twopc/utils" "vitess.io/vitess/go/test/endtoend/utils" "vitess.io/vitess/go/vt/log" + "vitess.io/vitess/go/vt/schema" ) // TestDisruptions tests that atomic transactions persevere through various disruptions. @@ -44,12 +47,12 @@ func TestDisruptions(t *testing.T) { testcases := []struct { disruptionName string commitDelayTime string - disruption func() error + disruption func(t *testing.T) error }{ { disruptionName: "No Disruption", commitDelayTime: "1", - disruption: func() error { + disruption: func(t *testing.T) error { return nil }, }, @@ -68,6 +71,11 @@ func TestDisruptions(t *testing.T) { commitDelayTime: "5", disruption: vttabletRestartShard3, }, + { + disruptionName: "OnlineDDL", + commitDelayTime: "20", + disruption: onlineDDL, + }, { disruptionName: "EmergencyReparentShard", commitDelayTime: "5", @@ -119,7 +127,7 @@ func TestDisruptions(t *testing.T) { }() } // Run the disruption. - err := tt.disruption() + err := tt.disruption(t) require.NoError(t, err) // Wait for the commit to have returned. We don't actually check for an error in the commit because the user might receive an error. // But since we are waiting in CommitPrepared, the decision to commit the transaction should have already been taken. @@ -145,6 +153,7 @@ func threadToWrite(t *testing.T, ctx context.Context, id int) { continue } _, _ = utils.ExecAllowError(t, conn, fmt.Sprintf("insert into twopc_t1(id, col) values(%d, %d)", id, rand.Intn(10000))) + conn.Close() } } @@ -170,11 +179,13 @@ func waitForResults(t *testing.T, query string, resultExpected string, waitTime ctx := context.Background() conn, err := mysql.Connect(ctx, &vtParams) if err == nil { - res := utils.Exec(t, conn, query) + res, _ := utils.ExecAllowError(t, conn, query) conn.Close() - prevRes = res.Rows - if fmt.Sprintf("%v", res.Rows) == resultExpected { - return + if res != nil { + prevRes = res.Rows + if fmt.Sprintf("%v", res.Rows) == resultExpected { + return + } } } time.Sleep(100 * time.Millisecond) @@ -187,14 +198,14 @@ Cluster Level Disruptions for the fuzzer */ // prsShard3 runs a PRS in shard 3 of the keyspace. It promotes the second tablet to be the new primary. -func prsShard3() error { +func prsShard3(t *testing.T) error { shard := clusterInstance.Keyspaces[0].Shards[2] newPrimary := shard.Vttablets[1] return clusterInstance.VtctldClientProcess.PlannedReparentShard(keyspaceName, shard.Name, newPrimary.Alias) } // ersShard3 runs a ERS in shard 3 of the keyspace. It promotes the second tablet to be the new primary. -func ersShard3() error { +func ersShard3(t *testing.T) error { shard := clusterInstance.Keyspaces[0].Shards[2] newPrimary := shard.Vttablets[1] _, err := clusterInstance.VtctldClientProcess.ExecuteCommandWithOutput("EmergencyReparentShard", fmt.Sprintf("%s/%s", keyspaceName, shard.Name), "--new-primary", newPrimary.Alias) @@ -202,14 +213,14 @@ func ersShard3() error { } // vttabletRestartShard3 restarts the first vttablet of the third shard. -func vttabletRestartShard3() error { +func vttabletRestartShard3(t *testing.T) error { shard := clusterInstance.Keyspaces[0].Shards[2] tablet := shard.Vttablets[0] return tablet.RestartOnlyTablet() } // mysqlRestartShard3 restarts MySQL on the first tablet of the third shard. -func mysqlRestartShard3() error { +func mysqlRestartShard3(t *testing.T) error { shard := clusterInstance.Keyspaces[0].Shards[2] vttablets := shard.Vttablets tablet := vttablets[0] @@ -227,3 +238,76 @@ func mysqlRestartShard3() error { } return syscallutil.Kill(pid, syscall.SIGKILL) } + +var orderedDDL = []string{ + "alter table twopc_t1 add column extra_col1 varchar(20)", + "alter table twopc_t1 add column extra_col2 varchar(20)", + "alter table twopc_t1 add column extra_col3 varchar(20)", + "alter table twopc_t1 add column extra_col4 varchar(20)", +} + +var count = 0 + +// onlineDDL runs a DDL statement. +func onlineDDL(t *testing.T) error { + output, err := clusterInstance.VtctldClientProcess.ApplySchemaWithOutput(keyspaceName, orderedDDL[count%len(orderedDDL)], cluster.ApplySchemaParams{ + DDLStrategy: "vitess --force-cut-over-after=1ms", + }) + require.NoError(t, err) + count++ + fmt.Println("uuid: ", output) + status := WaitForMigrationStatus(t, &vtParams, clusterInstance.Keyspaces[0].Shards, strings.TrimSpace(output), 2*time.Minute, schema.OnlineDDLStatusComplete, schema.OnlineDDLStatusFailed) + onlineddl.CheckMigrationStatus(t, &vtParams, clusterInstance.Keyspaces[0].Shards, strings.TrimSpace(output), status) + require.Equal(t, schema.OnlineDDLStatusComplete, status) + return nil +} + +func WaitForMigrationStatus(t *testing.T, vtParams *mysql.ConnParams, shards []cluster.Shard, uuid string, timeout time.Duration, expectStatuses ...schema.OnlineDDLStatus) schema.OnlineDDLStatus { + shardNames := map[string]bool{} + for _, shard := range shards { + shardNames[shard.Name] = true + } + query := fmt.Sprintf("show vitess_migrations like '%s'", uuid) + + statusesMap := map[string]bool{} + for _, status := range expectStatuses { + statusesMap[string(status)] = true + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + lastKnownStatus := "" + for { + countMatchedShards := 0 + conn, err := mysql.Connect(ctx, vtParams) + if err != nil { + continue + } + r, err := utils.ExecAllowError(t, conn, query) + conn.Close() + if err != nil { + continue + } + for _, row := range r.Named().Rows { + shardName := row["shard"].ToString() + if !shardNames[shardName] { + // irrelevant shard + continue + } + lastKnownStatus = row["migration_status"].ToString() + if row["migration_uuid"].ToString() == uuid && statusesMap[lastKnownStatus] { + countMatchedShards++ + } + } + if countMatchedShards == len(shards) { + return schema.OnlineDDLStatus(lastKnownStatus) + } + select { + case <-ctx.Done(): + return schema.OnlineDDLStatus(lastKnownStatus) + case <-ticker.C: + } + } +} diff --git a/go/test/endtoend/transaction/twopc/utils/utils.go b/go/test/endtoend/transaction/twopc/utils/utils.go index b3b8796accf..9d0ef838eb6 100644 --- a/go/test/endtoend/transaction/twopc/utils/utils.go +++ b/go/test/endtoend/transaction/twopc/utils/utils.go @@ -61,9 +61,9 @@ func ClearOutTable(t *testing.T, vtParams mysql.ConnParams, tableName string) { return } _, err = conn.ExecuteFetch(fmt.Sprintf("DELETE FROM %v LIMIT 10000", tableName), 10000, false) + conn.Close() if err != nil { fmt.Printf("Error in cleanup deletion - %v\n", err) - conn.Close() time.Sleep(100 * time.Millisecond) continue } diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index ae96fe9c1fe..ea7a6e93e0e 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -2818,3 +2818,24 @@ func (lock Lock) GetHighestOrderLock(newLock Lock) Lock { func Clone[K SQLNode](x K) K { return CloneSQLNode(x).(K) } + +// ExtractAllTables returns all the table names in the SQLNode as slice of string +func ExtractAllTables(stmt Statement) []string { + var tables []string + tableMap := make(map[string]any) + _ = Walk(func(node SQLNode) (kontinue bool, err error) { + switch node := node.(type) { + case *AliasedTableExpr: + if tblName, ok := node.Expr.(TableName); ok { + name := String(tblName) + if _, exists := tableMap[name]; !exists { + tableMap[name] = nil + tables = append(tables, name) + } + return false, nil + } + } + return true, nil + }, stmt) + return tables +} diff --git a/go/vt/sqlparser/ast_funcs_test.go b/go/vt/sqlparser/ast_funcs_test.go index 7bec47df96f..a3b744729f4 100644 --- a/go/vt/sqlparser/ast_funcs_test.go +++ b/go/vt/sqlparser/ast_funcs_test.go @@ -172,3 +172,50 @@ func TestColumns_Indexes(t *testing.T) { }) } } + +// TestExtractTables verifies the functionality of extracting all the tables from the SQLNode. +func TestExtractTables(t *testing.T) { + tcases := []struct { + sql string + expected []string + }{{ + sql: "select 1 from a", + expected: []string{"a"}, + }, { + sql: "select 1 from a, b", + expected: []string{"a", "b"}, + }, { + sql: "select 1 from a join b on a.id = b.id", + expected: []string{"a", "b"}, + }, { + sql: "select 1 from a join b on a.id = b.id join c on b.id = c.id", + expected: []string{"a", "b", "c"}, + }, { + sql: "select 1 from a join (select id from b) as c on a.id = c.id", + expected: []string{"a", "b"}, + }, { + sql: "(select 1 from a) union (select 1 from b)", + expected: []string{"a", "b"}, + }, { + sql: "select 1 from a where exists (select 1 from (select id from c) b where a.id = b.id)", + expected: []string{"a", "c"}, + }, { + sql: "select 1 from k.a join k.b on a.id = b.id", + expected: []string{"k.a", "k.b"}, + }, { + sql: "select 1 from k.a join l.a on k.a.id = l.a.id", + expected: []string{"k.a", "l.a"}, + }, { + sql: "select 1 from a join (select id from a) as c on a.id = c.id", + expected: []string{"a"}, + }} + parser := NewTestParser() + for _, tcase := range tcases { + t.Run(tcase.sql, func(t *testing.T) { + stmt, err := parser.Parse(tcase.sql) + require.NoError(t, err) + tables := ExtractAllTables(stmt) + require.Equal(t, tcase.expected, tables) + }) + } +} diff --git a/go/vt/vttablet/onlineddl/executor.go b/go/vt/vttablet/onlineddl/executor.go index a432581b3cd..f693dfe135c 100644 --- a/go/vt/vttablet/onlineddl/executor.go +++ b/go/vt/vttablet/onlineddl/executor.go @@ -178,6 +178,7 @@ type Executor struct { ts *topo.Server lagThrottler *throttle.Throttler toggleBufferTableFunc func(cancelCtx context.Context, tableName string, timeout time.Duration, bufferQueries bool) + isPreparedPoolEmpty func(tableName string) bool requestGCChecksFunc func() tabletAlias *topodatapb.TabletAlias @@ -238,6 +239,7 @@ func NewExecutor(env tabletenv.Env, tabletAlias *topodatapb.TabletAlias, ts *top tabletTypeFunc func() topodatapb.TabletType, toggleBufferTableFunc func(cancelCtx context.Context, tableName string, timeout time.Duration, bufferQueries bool), requestGCChecksFunc func(), + isPreparedPoolEmpty func(tableName string) bool, ) *Executor { // sanitize flags if maxConcurrentOnlineDDLs < 1 { @@ -255,6 +257,7 @@ func NewExecutor(env tabletenv.Env, tabletAlias *topodatapb.TabletAlias, ts *top ts: ts, lagThrottler: lagThrottler, toggleBufferTableFunc: toggleBufferTableFunc, + isPreparedPoolEmpty: isPreparedPoolEmpty, requestGCChecksFunc: requestGCChecksFunc, ticks: timer.NewTimer(migrationCheckInterval), // Gracefully return an error if any caller tries to execute @@ -870,7 +873,10 @@ func (e *Executor) killTableLockHoldersAndAccessors(ctx context.Context, tableNa threadId := row.AsInt64("trx_mysql_thread_id", 0) log.Infof("killTableLockHoldersAndAccessors: killing connection %v with transaction on table", threadId) killConnection := fmt.Sprintf("KILL %d", threadId) - _, _ = conn.Conn.ExecuteFetch(killConnection, 1, false) + _, err = conn.Conn.ExecuteFetch(killConnection, 1, false) + if err != nil { + log.Errorf("Unable to kill the connection %d: %v", threadId, err) + } } } } @@ -1102,6 +1108,11 @@ func (e *Executor) cutOverVReplMigration(ctx context.Context, s *VReplStream, sh time.Sleep(100 * time.Millisecond) if shouldForceCutOver { + // We should only proceed with forceful cut over if there is no pending atomic transaction for the table. + // This will help in keeping the atomicity guarantee of a prepared transaction. + if err := e.checkOnPreparedPool(ctx, onlineDDL.Table, 100*time.Millisecond); err != nil { + return err + } if err := e.killTableLockHoldersAndAccessors(ctx, onlineDDL.Table); err != nil { return err } @@ -5344,3 +5355,20 @@ func (e *Executor) OnSchemaMigrationStatus(ctx context.Context, return e.onSchemaMigrationStatus(ctx, uuidParam, status, dryRun, progressPct, etaSeconds, rowsCopied, hint) } + +func (e *Executor) checkOnPreparedPool(ctx context.Context, table string, waitTime time.Duration) error { + if e.isPreparedPoolEmpty(table) { + return nil + } + + select { + case <-ctx.Done(): + // Return context error if context is done + return ctx.Err() + case <-time.After(waitTime): + if e.isPreparedPoolEmpty(table) { + return nil + } + return vterrors.Errorf(vtrpcpb.Code_FAILED_PRECONDITION, "cannot force cut-over on non-empty prepared pool for table: %s", table) + } +} diff --git a/go/vt/vttablet/tabletserver/dt_executor.go b/go/vt/vttablet/tabletserver/dt_executor.go index 94e540c9a28..52e9366f322 100644 --- a/go/vt/vttablet/tabletserver/dt_executor.go +++ b/go/vt/vttablet/tabletserver/dt_executor.go @@ -24,7 +24,9 @@ import ( "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vttablet/tabletserver/rules" "vitess.io/vitess/go/vt/vttablet/tabletserver/tabletenv" "vitess.io/vitess/go/vt/vttablet/tabletserver/tx" ) @@ -34,13 +36,15 @@ type DTExecutor struct { ctx context.Context logStats *tabletenv.LogStats te *TxEngine + qe *QueryEngine } // NewDTExecutor creates a new distributed transaction executor. -func NewDTExecutor(ctx context.Context, te *TxEngine, logStats *tabletenv.LogStats) *DTExecutor { +func NewDTExecutor(ctx context.Context, te *TxEngine, qe *QueryEngine, logStats *tabletenv.LogStats) *DTExecutor { return &DTExecutor{ ctx: ctx, te: te, + qe: qe, logStats: logStats, } } @@ -83,12 +87,48 @@ func (dte *DTExecutor) Prepare(transactionID int64, dtid string) error { return vterrors.VT10002("cannot prepare the transaction on a reserved connection") } + // Fail Prepare if any query rule disallows it. + // This could be due to ongoing cutover happening in vreplication workflow + // regarding OnlineDDL or MoveTables. + for _, query := range conn.TxProperties().Queries { + qr := dte.qe.queryRuleSources.FilterByPlan(query.Sql, 0, query.Tables...) + if qr != nil { + act, _, _, _ := qr.GetAction("", "", nil, sqlparser.MarginComments{}) + if act != rules.QRContinue { + dte.te.txPool.RollbackAndRelease(dte.ctx, conn) + return vterrors.VT10002("cannot prepare the transaction due to query rule") + } + } + } + err = dte.te.preparedPool.Put(conn, dtid) if err != nil { dte.te.txPool.RollbackAndRelease(dte.ctx, conn) return vterrors.Errorf(vtrpcpb.Code_RESOURCE_EXHAUSTED, "prepare failed for transaction %d: %v", transactionID, err) } + // Recheck the rules. As some prepare transaction could have passed the first check. + // If they are put in the prepared pool, then vreplication workflow waits. + // This check helps reject the prepare that came later. + for _, query := range conn.TxProperties().Queries { + qr := dte.qe.queryRuleSources.FilterByPlan(query.Sql, 0, query.Tables...) + if qr != nil { + act, _, _, _ := qr.GetAction("", "", nil, sqlparser.MarginComments{}) + if act != rules.QRContinue { + dte.te.txPool.RollbackAndRelease(dte.ctx, conn) + dte.te.preparedPool.FetchForRollback(dtid) + return vterrors.VT10002("cannot prepare the transaction due to query rule") + } + } + } + + // If OnlineDDL killed the connection. We should avoid the prepare for it. + if conn.IsClosed() { + dte.te.txPool.RollbackAndRelease(dte.ctx, conn) + dte.te.preparedPool.FetchForRollback(dtid) + return vterrors.VT10002("cannot prepare the transaction on a closed connection") + } + return dte.inTransaction(func(localConn *StatefulConnection) error { return dte.te.twoPC.SaveRedo(dte.ctx, localConn, dtid, conn.TxProperties().Queries) }) diff --git a/go/vt/vttablet/tabletserver/dt_executor_test.go b/go/vt/vttablet/tabletserver/dt_executor_test.go index fb45ab454fc..f6369cfb978 100644 --- a/go/vt/vttablet/tabletserver/dt_executor_test.go +++ b/go/vt/vttablet/tabletserver/dt_executor_test.go @@ -26,6 +26,9 @@ import ( "time" "vitess.io/vitess/go/event/syslogger" + "vitess.io/vitess/go/vt/vtenv" + "vitess.io/vitess/go/vt/vttablet/tabletserver/rules" + "vitess.io/vitess/go/vt/vttablet/tabletserver/schema" "vitess.io/vitess/go/vt/vttablet/tabletserver/tx" "github.com/stretchr/testify/require" @@ -185,6 +188,60 @@ func TestTxExecutorPrepareRedoCommitFail(t *testing.T) { require.Contains(t, err.Error(), "commit fail") } +func TestExecutorPrepareRuleFailure(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + txe, tsv, db := newTestTxExecutor(t, ctx) + defer db.Close() + defer tsv.StopService() + + alterRule := rules.NewQueryRule("disable update", "disable update", rules.QRBuffer) + alterRule.AddTableCond("test_table") + + r := rules.New() + r.Add(alterRule) + txe.qe.queryRuleSources.RegisterSource("bufferQuery") + err := txe.qe.queryRuleSources.SetRules("bufferQuery", r) + require.NoError(t, err) + + // start a transaction + txid := newTxForPrep(ctx, tsv) + + // taint the connection. + sc, err := tsv.te.txPool.GetAndLock(txid, "adding query property") + require.NoError(t, err) + sc.txProps.Queries = append(sc.txProps.Queries, tx.Query{ + Sql: "update test_table set col = 5", + Tables: []string{"test_table"}, + }) + sc.Unlock() + + // try 2pc commit of Metadata Manager. + err = txe.Prepare(txid, "aa") + require.EqualError(t, err, "VT10002: atomic distributed transaction not allowed: cannot prepare the transaction due to query rule") +} + +func TestExecutorPrepareConnFailure(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + txe, tsv, db := newTestTxExecutor(t, ctx) + defer db.Close() + defer tsv.StopService() + + // start a transaction + txid := newTxForPrep(ctx, tsv) + + // taint the connection. + sc, err := tsv.te.txPool.GetAndLock(txid, "adding query property") + require.NoError(t, err) + sc.Unlock() + sc.dbConn.Close() + + // try 2pc commit of Metadata Manager. + err = txe.Prepare(txid, "aa") + require.EqualError(t, err, "VT10002: atomic distributed transaction not allowed: cannot prepare the transaction on a closed connection") +} + func TestTxExecutorCommit(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -610,6 +667,11 @@ func newTestTxExecutor(t *testing.T, ctx context.Context) (txe *DTExecutor, tsv db = setUpQueryExecutorTest(t) logStats := tabletenv.NewLogStats(ctx, "TestTxExecutor") tsv = newTestTabletServer(ctx, smallTxPool, db) + cfg := tabletenv.NewDefaultConfig() + cfg.DB = newDBConfigs(db) + env := tabletenv.NewEnv(vtenv.NewTestEnv(), cfg, "TabletServerTest") + se := schema.NewEngine(env) + qe := NewQueryEngine(env, se) db.AddQueryPattern("insert into _vt\\.redo_state\\(dtid, state, time_created\\) values \\('aa', 1,.*", &sqltypes.Result{}) db.AddQueryPattern("insert into _vt\\.redo_statement.*", &sqltypes.Result{}) db.AddQuery("delete from _vt.redo_state where dtid = 'aa'", &sqltypes.Result{}) @@ -619,6 +681,7 @@ func newTestTxExecutor(t *testing.T, ctx context.Context) (txe *DTExecutor, tsv ctx: ctx, logStats: logStats, te: tsv.te, + qe: qe, }, tsv, db } diff --git a/go/vt/vttablet/tabletserver/planbuilder/builder.go b/go/vt/vttablet/tabletserver/planbuilder/builder.go index 94f5fc1caa2..6df89f7caf8 100644 --- a/go/vt/vttablet/tabletserver/planbuilder/builder.go +++ b/go/vt/vttablet/tabletserver/planbuilder/builder.go @@ -33,7 +33,7 @@ func analyzeSelect(env *vtenv.Environment, sel *sqlparser.Select, tables map[str PlanID: PlanSelect, FullQuery: GenerateLimitQuery(sel), } - plan.Table, plan.AllTables = lookupTables(sel.From, tables) + plan.Table = lookupTables(sel.From, tables) if sel.Where != nil { comp, ok := sel.Where.Expr.(*sqlparser.ComparisonExpr) @@ -72,7 +72,7 @@ func analyzeUpdate(upd *sqlparser.Update, tables map[string]*schema.Table) (plan plan = &Plan{ PlanID: PlanUpdate, } - plan.Table, plan.AllTables = lookupTables(upd.TableExprs, tables) + plan.Table = lookupTables(upd.TableExprs, tables) // Store the WHERE clause as string for the hot row protection (txserializer). if upd.Where != nil { @@ -102,7 +102,7 @@ func analyzeDelete(del *sqlparser.Delete, tables map[string]*schema.Table) (plan plan = &Plan{ PlanID: PlanDelete, } - plan.Table, plan.AllTables = lookupTables(del.TableExprs, tables) + plan.Table = lookupTables(del.TableExprs, tables) if del.Where != nil { buf := sqlparser.NewTrackedBuffer(nil) @@ -127,11 +127,7 @@ func analyzeInsert(ins *sqlparser.Insert, tables map[string]*schema.Table) (plan FullQuery: GenerateFullQuery(ins), } - tableName, err := ins.Table.TableName() - if err != nil { - return nil, err - } - plan.Table = tables[sqlparser.GetTableName(tableName).String()] + plan.Table = lookupTables(sqlparser.TableExprs{ins.Table}, tables) return plan, nil } @@ -188,16 +184,26 @@ func analyzeSet(set *sqlparser.Set) (plan *Plan) { } } -func lookupTables(tableExprs sqlparser.TableExprs, tables map[string]*schema.Table) (singleTable *schema.Table, allTables []*schema.Table) { +func lookupTables(tableExprs sqlparser.TableExprs, tables map[string]*schema.Table) (singleTable *schema.Table) { for _, tableExpr := range tableExprs { if t := lookupSingleTable(tableExpr, tables); t != nil { - allTables = append(allTables, t) + if singleTable != nil { + return nil + } + singleTable = t } } - if len(allTables) == 1 { - singleTable = allTables[0] + return singleTable +} + +func lookupAllTables(stmt sqlparser.Statement, tables map[string]*schema.Table) (allTables []*schema.Table) { + tablesUsed := sqlparser.ExtractAllTables(stmt) + for _, tbl := range tablesUsed { + if t := tables[tbl]; t != nil { + allTables = append(allTables, t) + } } - return singleTable, allTables + return allTables } func lookupSingleTable(tableExpr sqlparser.TableExpr, tables map[string]*schema.Table) *schema.Table { @@ -229,12 +235,14 @@ func analyzeFlush(stmt *sqlparser.Flush, tables map[string]*schema.Table) (*Plan for _, tbl := range stmt.TableNames { if schemaTbl, ok := tables[tbl.Name.String()]; ok { - plan.AllTables = append(plan.AllTables, schemaTbl) + if plan.Table != nil { + // If there are multiple tables, we empty out the table field. + plan.Table = nil + break + } + plan.Table = schemaTbl } } - if len(plan.AllTables) == 1 { - plan.Table = plan.AllTables[0] - } if stmt.WithLock { plan.NeedsReservedConn = true diff --git a/go/vt/vttablet/tabletserver/planbuilder/plan.go b/go/vt/vttablet/tabletserver/planbuilder/plan.go index 7b1e57c2f90..f18ea59a714 100644 --- a/go/vt/vttablet/tabletserver/planbuilder/plan.go +++ b/go/vt/vttablet/tabletserver/planbuilder/plan.go @@ -157,7 +157,7 @@ type Plan struct { PlanID PlanType // When the query indicates a single table Table *schema.Table - // SELECT, UPDATE, DELETE statements may list multiple tables + // This indicates all the tables that are accessed in the query. AllTables []*schema.Table // Permissions stores the permissions for the tables accessed in the query. @@ -257,6 +257,7 @@ func Build(env *vtenv.Environment, statement sqlparser.Statement, tables map[str if err != nil { return nil, err } + plan.AllTables = lookupAllTables(statement, tables) plan.Permissions = BuildPermissions(statement) return plan, nil } @@ -274,14 +275,14 @@ func BuildStreaming(statement sqlparser.Statement, tables map[string]*schema.Tab if hasLockFunc(stmt) { plan.NeedsReservedConn = true } - plan.Table, plan.AllTables = lookupTables(stmt.From, tables) + plan.Table = lookupTables(stmt.From, tables) case *sqlparser.Show, *sqlparser.Union, *sqlparser.CallProc, sqlparser.Explain: case *sqlparser.Analyze: plan.PlanID = PlanOtherRead default: return nil, vterrors.Errorf(vtrpcpb.Code_FAILED_PRECONDITION, "%s not allowed for streaming", sqlparser.ASTToStatementType(statement)) } - + plan.AllTables = lookupAllTables(statement, tables) return plan, nil } diff --git a/go/vt/vttablet/tabletserver/query_executor.go b/go/vt/vttablet/tabletserver/query_executor.go index 02b8dd9171a..e89dee889dc 100644 --- a/go/vt/vttablet/tabletserver/query_executor.go +++ b/go/vt/vttablet/tabletserver/query_executor.go @@ -670,7 +670,6 @@ func (qre *QueryExecutor) execNextval() (*sqltypes.Result, error) { newLast += cache } query = fmt.Sprintf("update %s set next_id = %d where id = 0", sqlparser.String(tableName), newLast) - conn.TxProperties().RecordQuery(query) _, err = qre.execStatefulConn(conn, query, false) if err != nil { return nil, err @@ -807,7 +806,7 @@ func (qre *QueryExecutor) txFetch(conn *StatefulConnection, record bool) (*sqlty } // Only record successful queries. if record { - conn.TxProperties().RecordQuery(sql) + conn.TxProperties().RecordQueryDetail(sql, qre.plan.TableNames()) } return qr, nil } diff --git a/go/vt/vttablet/tabletserver/tabletserver.go b/go/vt/vttablet/tabletserver/tabletserver.go index 05466721224..4fa52796d97 100644 --- a/go/vt/vttablet/tabletserver/tabletserver.go +++ b/go/vt/vttablet/tabletserver/tabletserver.go @@ -186,7 +186,7 @@ func NewTabletServer(ctx context.Context, env *vtenv.Environment, name string, c tsv.messager = messager.NewEngine(tsv, tsv.se, tsv.vstreamer) tsv.tableGC = gc.NewTableGC(tsv, topoServer, tsv.lagThrottler) - tsv.onlineDDLExecutor = onlineddl.NewExecutor(tsv, alias, topoServer, tsv.lagThrottler, tabletTypeFunc, tsv.onlineDDLExecutorToggleTableBuffer, tsv.tableGC.RequestChecks) + tsv.onlineDDLExecutor = onlineddl.NewExecutor(tsv, alias, topoServer, tsv.lagThrottler, tabletTypeFunc, tsv.onlineDDLExecutorToggleTableBuffer, tsv.tableGC.RequestChecks, tsv.te.preparedPool.IsEmpty) tsv.sm = &stateManager{ statelessql: tsv.statelessql, @@ -641,7 +641,7 @@ func (tsv *TabletServer) Prepare(ctx context.Context, target *querypb.Target, tr "Prepare", "prepare", nil, target, nil, true, /* allowOnShutdown */ func(ctx context.Context, logStats *tabletenv.LogStats) error { - txe := NewDTExecutor(ctx, tsv.te, logStats) + txe := NewDTExecutor(ctx, tsv.te, tsv.qe, logStats) return txe.Prepare(transactionID, dtid) }, ) @@ -654,7 +654,7 @@ func (tsv *TabletServer) CommitPrepared(ctx context.Context, target *querypb.Tar "CommitPrepared", "commit_prepared", nil, target, nil, true, /* allowOnShutdown */ func(ctx context.Context, logStats *tabletenv.LogStats) error { - txe := NewDTExecutor(ctx, tsv.te, logStats) + txe := NewDTExecutor(ctx, tsv.te, tsv.qe, logStats) if DebugTwoPc { commitPreparedDelayForTest(tsv) } @@ -670,7 +670,7 @@ func (tsv *TabletServer) RollbackPrepared(ctx context.Context, target *querypb.T "RollbackPrepared", "rollback_prepared", nil, target, nil, true, /* allowOnShutdown */ func(ctx context.Context, logStats *tabletenv.LogStats) error { - txe := NewDTExecutor(ctx, tsv.te, logStats) + txe := NewDTExecutor(ctx, tsv.te, tsv.qe, logStats) return txe.RollbackPrepared(dtid, originalID) }, ) @@ -683,7 +683,7 @@ func (tsv *TabletServer) CreateTransaction(ctx context.Context, target *querypb. "CreateTransaction", "create_transaction", nil, target, nil, true, /* allowOnShutdown */ func(ctx context.Context, logStats *tabletenv.LogStats) error { - txe := NewDTExecutor(ctx, tsv.te, logStats) + txe := NewDTExecutor(ctx, tsv.te, tsv.qe, logStats) return txe.CreateTransaction(dtid, participants) }, ) @@ -697,7 +697,7 @@ func (tsv *TabletServer) StartCommit(ctx context.Context, target *querypb.Target "StartCommit", "start_commit", nil, target, nil, true, /* allowOnShutdown */ func(ctx context.Context, logStats *tabletenv.LogStats) error { - txe := NewDTExecutor(ctx, tsv.te, logStats) + txe := NewDTExecutor(ctx, tsv.te, tsv.qe, logStats) return txe.StartCommit(transactionID, dtid) }, ) @@ -711,7 +711,7 @@ func (tsv *TabletServer) SetRollback(ctx context.Context, target *querypb.Target "SetRollback", "set_rollback", nil, target, nil, true, /* allowOnShutdown */ func(ctx context.Context, logStats *tabletenv.LogStats) error { - txe := NewDTExecutor(ctx, tsv.te, logStats) + txe := NewDTExecutor(ctx, tsv.te, tsv.qe, logStats) return txe.SetRollback(dtid, transactionID) }, ) @@ -725,7 +725,7 @@ func (tsv *TabletServer) ConcludeTransaction(ctx context.Context, target *queryp "ConcludeTransaction", "conclude_transaction", nil, target, nil, true, /* allowOnShutdown */ func(ctx context.Context, logStats *tabletenv.LogStats) error { - txe := NewDTExecutor(ctx, tsv.te, logStats) + txe := NewDTExecutor(ctx, tsv.te, tsv.qe, logStats) return txe.ConcludeTransaction(dtid) }, ) @@ -738,7 +738,7 @@ func (tsv *TabletServer) ReadTransaction(ctx context.Context, target *querypb.Ta "ReadTransaction", "read_transaction", nil, target, nil, true, /* allowOnShutdown */ func(ctx context.Context, logStats *tabletenv.LogStats) error { - txe := NewDTExecutor(ctx, tsv.te, logStats) + txe := NewDTExecutor(ctx, tsv.te, tsv.qe, logStats) metadata, err = txe.ReadTransaction(dtid) return err }, @@ -753,7 +753,7 @@ func (tsv *TabletServer) UnresolvedTransactions(ctx context.Context, target *que "UnresolvedTransactions", "unresolved_transaction", nil, target, nil, false, /* allowOnShutdown */ func(ctx context.Context, logStats *tabletenv.LogStats) error { - txe := NewDTExecutor(ctx, tsv.te, logStats) + txe := NewDTExecutor(ctx, tsv.te, tsv.qe, logStats) transactions, err = txe.UnresolvedTransactions() return err }, @@ -1776,7 +1776,7 @@ func (tsv *TabletServer) registerQueryListHandlers(queryLists []*QueryList) { func (tsv *TabletServer) registerTwopczHandler() { tsv.exporter.HandleFunc("/twopcz", func(w http.ResponseWriter, r *http.Request) { ctx := tabletenv.LocalContext() - txe := NewDTExecutor(ctx, tsv.te, tabletenv.NewLogStats(ctx, "twopcz")) + txe := NewDTExecutor(ctx, tsv.te, tsv.qe, tabletenv.NewLogStats(ctx, "twopcz")) twopczHandler(txe, w, r) }) } diff --git a/go/vt/vttablet/tabletserver/tabletserver_test.go b/go/vt/vttablet/tabletserver/tabletserver_test.go index 7f863e26df7..e70b575f464 100644 --- a/go/vt/vttablet/tabletserver/tabletserver_test.go +++ b/go/vt/vttablet/tabletserver/tabletserver_test.go @@ -36,6 +36,7 @@ import ( "vitess.io/vitess/go/vt/callerid" "vitess.io/vitess/go/vt/sidecardb" "vitess.io/vitess/go/vt/vtenv" + "vitess.io/vitess/go/vt/vttablet/tabletserver/tx" "vitess.io/vitess/go/mysql/fakesqldb" "vitess.io/vitess/go/test/utils" @@ -241,7 +242,9 @@ func TestTabletServerRedoLogIsKeptBetweenRestarts(t *testing.T) { turnOnTxEngine() assert.EqualValues(t, 1, len(tsv.te.preparedPool.conns), "len(tsv.te.preparedPool.conns)") got := tsv.te.preparedPool.conns["dtid0"].TxProperties().Queries - want := []string{"update test_table set `name` = 2 where pk = 1 limit 10001"} + want := []tx.Query{{ + Sql: "update test_table set `name` = 2 where pk = 1 limit 10001", + Tables: []string{"test_table"}}} utils.MustMatch(t, want, got, "Prepared queries") turnOffTxEngine() assert.Empty(t, tsv.te.preparedPool.conns, "tsv.te.preparedPool.conns") @@ -275,7 +278,9 @@ func TestTabletServerRedoLogIsKeptBetweenRestarts(t *testing.T) { turnOnTxEngine() assert.EqualValues(t, 1, len(tsv.te.preparedPool.conns), "len(tsv.te.preparedPool.conns)") got = tsv.te.preparedPool.conns["a:b:10"].TxProperties().Queries - want = []string{"update test_table set `name` = 2 where pk = 1 limit 10001"} + want = []tx.Query{{ + Sql: "update test_table set `name` = 2 where pk = 1 limit 10001", + Tables: []string{"test_table"}}} utils.MustMatch(t, want, got, "Prepared queries") wantFailed := map[string]error{"a:b:20": errPrepFailed} utils.MustMatch(t, tsv.te.preparedPool.reserved, wantFailed, fmt.Sprintf("Failed dtids: %v, want %v", tsv.te.preparedPool.reserved, wantFailed)) diff --git a/go/vt/vttablet/tabletserver/twopc.go b/go/vt/vttablet/tabletserver/twopc.go index b3c5ab628c3..ffb4bae39a5 100644 --- a/go/vt/vttablet/tabletserver/twopc.go +++ b/go/vt/vttablet/tabletserver/twopc.go @@ -169,7 +169,7 @@ func (tpc *TwoPC) Close() { } // SaveRedo saves the statements in the redo log using the supplied connection. -func (tpc *TwoPC) SaveRedo(ctx context.Context, conn *StatefulConnection, dtid string, queries []string) error { +func (tpc *TwoPC) SaveRedo(ctx context.Context, conn *StatefulConnection, dtid string, queries []tx.Query) error { bindVars := map[string]*querypb.BindVariable{ "dtid": sqltypes.StringBindVariable(dtid), "state": sqltypes.Int64BindVariable(RedoStatePrepared), @@ -185,7 +185,7 @@ func (tpc *TwoPC) SaveRedo(ctx context.Context, conn *StatefulConnection, dtid s rows[i] = []sqltypes.Value{ sqltypes.NewVarBinary(dtid), sqltypes.NewInt64(int64(i + 1)), - sqltypes.NewVarBinary(query), + sqltypes.NewVarBinary(query.Sql), } } extras := map[string]sqlparser.Encodable{ diff --git a/go/vt/vttablet/tabletserver/tx/api.go b/go/vt/vttablet/tabletserver/tx/api.go index a392e530ffa..48a1cc1107a 100644 --- a/go/vt/vttablet/tabletserver/tx/api.go +++ b/go/vt/vttablet/tabletserver/tx/api.go @@ -31,11 +31,11 @@ type ( // ConnID as type int64 ConnID = int64 - //DTID as type string + // DTID as type string DTID = string - //EngineStateMachine is used to control the state the transactional engine - - //whether new connections and/or transactions are allowed or not. + // EngineStateMachine is used to control the state the transactional engine - + // whether new connections and/or transactions are allowed or not. EngineStateMachine interface { Init() error AcceptReadWrite() error @@ -46,14 +46,14 @@ type ( // ReleaseReason as type int ReleaseReason int - //Properties contains all information that is related to the currently running - //transaction on the connection + // Properties contains all information that is related to the currently running + // transaction on the connection Properties struct { EffectiveCaller *vtrpcpb.CallerID ImmediateCaller *querypb.VTGateCallerID StartTime time.Time EndTime time.Time - Queries []string + Queries []Query Autocommit bool Conclusion string LogToFile bool @@ -62,6 +62,11 @@ type ( } ) +type Query struct { + Sql string + Tables []string +} + const ( // TxClose - connection released on close. TxClose ReleaseReason = iota @@ -114,12 +119,33 @@ var txNames = map[ReleaseReason]string{ ConnRenewFail: "renewFail", } -// RecordQuery records the query against this transaction. -func (p *Properties) RecordQuery(query string) { +// RecordQueryDetail records the query and tables against this transaction. +func (p *Properties) RecordQueryDetail(query string, tables []string) { if p == nil { return } - p.Queries = append(p.Queries, query) + p.Queries = append(p.Queries, Query{ + Sql: query, + Tables: tables, + }) +} + +// RecordQuery records the query and extract tables against this transaction. +func (p *Properties) RecordQuery(query string, parser *sqlparser.Parser) { + if p == nil { + return + } + stmt, err := parser.Parse(query) + if err != nil { + // This should neven happen, but if it does, + // we would not be able to block cut-overs on this query. + return + } + tables := sqlparser.ExtractAllTables(stmt) + p.Queries = append(p.Queries, Query{ + Sql: query, + Tables: tables, + }) } // InTransaction returns true as soon as this struct is not nil @@ -134,10 +160,11 @@ func (p *Properties) String(sanitize bool, parser *sqlparser.Parser) string { printQueries := func() string { sb := strings.Builder{} for _, query := range p.Queries { + sql := query.Sql if sanitize { - query, _ = parser.RedactSQLQuery(query) + sql, _ = parser.RedactSQLQuery(sql) } - sb.WriteString(query) + sb.WriteString(sql) sb.WriteString(";") } return sb.String() diff --git a/go/vt/vttablet/tabletserver/tx_engine.go b/go/vt/vttablet/tabletserver/tx_engine.go index c465dc00578..d7f2d55b18a 100644 --- a/go/vt/vttablet/tabletserver/tx_engine.go +++ b/go/vt/vttablet/tabletserver/tx_engine.go @@ -122,7 +122,7 @@ func NewTxEngine(env tabletenv.Env, dxNotifier func()) *TxEngine { // perform metadata state change operations. Without this, // the system can deadlock if all connections get moved to // the TxPreparedPool. - te.preparedPool = NewTxPreparedPool(config.TxPool.Size - 2) + te.preparedPool = NewTxPreparedPool(config.TxPool.Size-2, te.twopcEnabled) readPool := connpool.NewPool(env, "TxReadPool", tabletenv.ConnPoolConfig{ Size: 3, IdleTimeout: env.Config().TxPool.IdleTimeout, @@ -434,7 +434,7 @@ outer: continue } for _, stmt := range preparedTx.Queries { - conn.TxProperties().RecordQuery(stmt) + conn.TxProperties().RecordQuery(stmt, te.env.Environment().Parser()) _, err := conn.Exec(ctx, stmt, 1, false) if err != nil { allErr.RecordError(vterrors.Wrapf(err, "dtid - %v", preparedTx.Dtid)) diff --git a/go/vt/vttablet/tabletserver/tx_prep_pool.go b/go/vt/vttablet/tabletserver/tx_prep_pool.go index c801e208e33..468c160c002 100644 --- a/go/vt/vttablet/tabletserver/tx_prep_pool.go +++ b/go/vt/vttablet/tabletserver/tx_prep_pool.go @@ -39,21 +39,45 @@ type TxPreparedPool struct { // open tells if the prepared pool is open for accepting transactions. open bool capacity int + // twoPCEnabled is set to true if 2PC is enabled. + twoPCEnabled bool } // NewTxPreparedPool creates a new TxPreparedPool. -func NewTxPreparedPool(capacity int) *TxPreparedPool { +func NewTxPreparedPool(capacity int, twoPCEnabled bool) *TxPreparedPool { if capacity < 0 { // If capacity is 0 all prepares will fail. capacity = 0 } return &TxPreparedPool{ - conns: make(map[string]*StatefulConnection, capacity), - reserved: make(map[string]error), - capacity: capacity, + conns: make(map[string]*StatefulConnection, capacity), + reserved: make(map[string]error), + capacity: capacity, + twoPCEnabled: twoPCEnabled, } } +// Open marks the prepared pool open for use. +func (pp *TxPreparedPool) Open() { + pp.mu.Lock() + defer pp.mu.Unlock() + pp.open = true +} + +// Close marks the prepared pool closed. +func (pp *TxPreparedPool) Close() { + pp.mu.Lock() + defer pp.mu.Unlock() + pp.open = false +} + +// IsOpen checks if the prepared pool is open for use. +func (pp *TxPreparedPool) IsOpen() bool { + pp.mu.Lock() + defer pp.mu.Unlock() + return pp.open +} + // Put adds the connection to the pool. It returns an error // if the pool is full or on duplicate key. func (pp *TxPreparedPool) Put(c *StatefulConnection, dtid string) error { @@ -93,27 +117,6 @@ func (pp *TxPreparedPool) FetchForRollback(dtid string) *StatefulConnection { return c } -// Open marks the prepared pool open for use. -func (pp *TxPreparedPool) Open() { - pp.mu.Lock() - defer pp.mu.Unlock() - pp.open = true -} - -// Close marks the prepared pool closed. -func (pp *TxPreparedPool) Close() { - pp.mu.Lock() - defer pp.mu.Unlock() - pp.open = false -} - -// IsOpen checks if the prepared pool is open for use. -func (pp *TxPreparedPool) IsOpen() bool { - pp.mu.Lock() - defer pp.mu.Unlock() - return pp.open -} - // FetchForCommit returns the connection for commit. Before returning, // it remembers the dtid in its reserved list as "committing". If // the dtid is already in the reserved list, it returns an error. @@ -169,3 +172,25 @@ func (pp *TxPreparedPool) FetchAllForRollback() []*StatefulConnection { pp.reserved = make(map[string]error) return conns } + +func (pp *TxPreparedPool) IsEmpty(tableName string) bool { + pp.mu.Lock() + defer pp.mu.Unlock() + if !pp.twoPCEnabled { + return true + } + // If the pool is shutdown, we do not know the correct state of prepared transactions. + if !pp.open { + return false + } + for _, connection := range pp.conns { + for _, query := range connection.txProps.Queries { + for _, table := range query.Tables { + if table == tableName { + return false + } + } + } + } + return true +} diff --git a/go/vt/vttablet/tabletserver/tx_prep_pool_test.go b/go/vt/vttablet/tabletserver/tx_prep_pool_test.go index 43c0c022b13..cf6d2b61093 100644 --- a/go/vt/vttablet/tabletserver/tx_prep_pool_test.go +++ b/go/vt/vttablet/tabletserver/tx_prep_pool_test.go @@ -112,7 +112,7 @@ func TestPrepFetchAll(t *testing.T) { // createAndOpenPreparedPool creates a new transaction prepared pool and opens it. // Used as a helper function for testing. func createAndOpenPreparedPool(capacity int) *TxPreparedPool { - pp := NewTxPreparedPool(capacity) + pp := NewTxPreparedPool(capacity, true) pp.Open() return pp } diff --git a/go/vt/vttablet/tabletserver/txlogz_test.go b/go/vt/vttablet/tabletserver/txlogz_test.go index 319669a0023..8faec74d07b 100644 --- a/go/vt/vttablet/tabletserver/txlogz_test.go +++ b/go/vt/vttablet/tabletserver/txlogz_test.go @@ -60,7 +60,7 @@ func testHandler(req *http.Request, t *testing.T) { ImmediateCaller: callerid.NewImmediateCallerID("immediate-caller"), StartTime: time.Now(), Conclusion: "unknown", - Queries: []string{"select * from test"}, + Queries: []tx.Query{{Sql: "select * from test"}}, }, } txConn.txProps.EndTime = txConn.txProps.StartTime