Skip to content

Commit

Permalink
Add savepoint support to atomic distributed transaction (#16863)
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <harshit@planetscale.com>
Signed-off-by: Manan Gupta <manan@planetscale.com>
Co-authored-by: Manan Gupta <manan@planetscale.com>
  • Loading branch information
harshit-gangal and GuptaManan100 authored Oct 10, 2024
1 parent d75272c commit a7b903b
Show file tree
Hide file tree
Showing 18 changed files with 778 additions and 157 deletions.
11 changes: 11 additions & 0 deletions go/test/endtoend/cluster/cluster_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,17 @@ func (shard *Shard) PrimaryTablet() *Vttablet {
return shard.Vttablets[0]
}

// FindPrimaryTablet walks the shard's tablets in order and returns the first
// one whose process reports the tablet type "primary". It returns nil when no
// tablet in the shard reports that type.
func (shard *Shard) FindPrimaryTablet() *Vttablet {
	for i := range shard.Vttablets {
		candidate := shard.Vttablets[i]
		if candidate.VttabletProcess.GetTabletType() != "primary" {
			continue
		}
		return candidate
	}
	return nil
}

// Rdonly gets the last tablet which is rdonly
func (shard *Shard) Rdonly() *Vttablet {
for idx, tablet := range shard.Vttablets {
Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/cluster/reshard.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (rw *ReshardWorkflow) WaitForVreplCatchup(timeToWait time.Duration) {
if !slices.Contains(targetShards, shard.Name) {
continue
}
vttablet := shard.PrimaryTablet().VttabletProcess
vttablet := shard.FindPrimaryTablet().VttabletProcess
vttablet.WaitForVReplicationToCatchup(rw.t, rw.workflowName, fmt.Sprintf("vt_%s", vttablet.Keyspace), "", timeToWait)
}
}
Expand Down
57 changes: 55 additions & 2 deletions go/test/endtoend/transaction/twopc/fuzz/fuzzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"os"
"path"
"slices"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -126,7 +127,18 @@ func TestTwoPCFuzzTest(t *testing.T) {
fz.start(t)

// Wait for the timeForTesting so that the threads continue to run.
time.Sleep(tt.timeForTesting)
timeout := time.After(tt.timeForTesting)
loop := true
for loop {
select {
case <-timeout:
loop = false
case <-time.After(1 * time.Second):
if t.Failed() {
loop = false
}
}
}

// Signal the fuzzer to stop.
fz.stop()
Expand Down Expand Up @@ -302,9 +314,11 @@ func (fz *fuzzer) generateAndExecuteTransaction(threadId int) {
// for each update set ordered by the auto increment column will not be true.
// That assertion depends on all the transactions running updates first to ensure that for any given update set,
// no two transactions are running the insert queries.
queries := []string{"begin"}
var queries []string
queries = append(queries, fz.generateUpdateQueries(updateSetVal, incrementVal)...)
queries = append(queries, fz.generateInsertQueries(updateSetVal, threadId)...)
queries = fz.addRandomSavePoints(queries)
queries = append([]string{"begin"}, queries...)
finalCommand := "commit"
for _, query := range queries {
_, err := conn.ExecuteFetch(query, 0, false)
Expand Down Expand Up @@ -377,6 +391,45 @@ func (fz *fuzzer) runClusterDisruption(t *testing.T) {
}
}

// addRandomSavePoints splices a random number of savepoint blocks into the
// given list of queries and returns the resulting list. The input slice is
// never modified.
//
// Each block has the shape
//
//	SAVEPOINT spN, <one or two random DMLs>, ROLLBACK TO spN
//
// so every query added here is rolled back again and the assertions made on
// the original queries do not change. A block is inserted at a random index
// in [0, len(queries)) of the current list, which means it may land inside a
// previously added block but is never placed after the last query.
//
// An empty list is returned unchanged, since it has no valid insertion index.
func (fz *fuzzer) addRandomSavePoints(queries []string) []string {
	// rand.Intn panics for n <= 0, so bail out before trying to pick a
	// position inside an empty list.
	if len(queries) == 0 {
		return queries
	}

	savePointCount := 1
	for {
		// Flip a coin before every block: stop on 0, add one more block on 1.
		if rand.Intn(2) == 0 {
			return queries
		}

		name := "sp" + strconv.Itoa(savePointCount)
		savePointCount++

		savePointQueries := []string{"SAVEPOINT " + name}
		randomDMLCount := rand.Intn(2) + 1
		for i := 0; i < randomDMLCount; i++ {
			savePointQueries = append(savePointQueries, fz.randomDML())
		}
		savePointQueries = append(savePointQueries, "ROLLBACK TO "+name)

		// Insert into a clone so the caller's backing array is never written to.
		savePointPosition := rand.Intn(len(queries))
		queries = slices.Insert(slices.Clone(queries), savePointPosition, savePointQueries...)
	}
}

// randomDML returns one randomly chosen DML statement. With equal probability
// it is either the insertIntoFuzzInsert format filled with a random element of
// updateRowBaseVals, a random int below fz.updateSets and a random int below
// fz.threads, or the updateFuzzUpdate format filled with a random int below
// 100000 and a random id picked from fz.updateRowsVals.
func (fz *fuzzer) randomDML() string {
	switch rand.Intn(2) {
	case 0:
		// INSERT flavour.
		baseVal := updateRowBaseVals[rand.Intn(len(updateRowBaseVals))]
		updateSet := rand.Intn(fz.updateSets)
		thread := rand.Intn(fz.threads)
		return fmt.Sprintf(insertIntoFuzzInsert, baseVal, updateSet, thread)
	default:
		// UPDATE flavour: pick a row set first, then an id inside it.
		rowSet := fz.updateRowsVals[rand.Intn(len(fz.updateRowsVals))]
		target := rowSet[rand.Intn(len(updateRowBaseVals))]
		newVal := rand.Intn(100000)
		return fmt.Sprintf(updateFuzzUpdate, newVal, target)
	}
}

/*
Cluster Level Disruptions for the fuzzer
*/
Expand Down
77 changes: 73 additions & 4 deletions go/test/endtoend/transaction/twopc/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,19 @@ import (
"fmt"
"io"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/endtoend/utils"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/test/endtoend/transaction/twopc/utils"
twopcutil "vitess.io/vitess/go/test/endtoend/transaction/twopc/utils"
binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
querypb "vitess.io/vitess/go/vt/proto/query"
topodatapb "vitess.io/vitess/go/vt/proto/topodata"
Expand All @@ -42,6 +45,7 @@ import (

var (
clusterInstance *cluster.LocalProcessCluster
mysqlParams mysql.ConnParams
vtParams mysql.ConnParams
vtgateGrpcAddress string
keyspaceName = "ks"
Expand Down Expand Up @@ -81,6 +85,8 @@ func TestMain(m *testing.M) {
"--twopc_enable",
"--twopc_abandon_age", "1",
"--queryserver-config-transaction-cap", "3",
"--queryserver-config-transaction-timeout", "400s",
"--queryserver-config-query-timeout", "9000s",
)

// Start keyspace
Expand All @@ -102,6 +108,15 @@ func TestMain(m *testing.M) {
vtParams = clusterInstance.GetVTParams(keyspaceName)
vtgateGrpcAddress = fmt.Sprintf("%s:%d", clusterInstance.Hostname, clusterInstance.VtgateGrpcPort)

// create mysql instance and connection parameters
conn, closer, err := utils.NewMySQL(clusterInstance, keyspaceName, SchemaSQL)
if err != nil {
fmt.Println(err)
return 1
}
defer closer()
mysqlParams = conn

return m.Run()
}()
os.Exit(exitcode)
Expand All @@ -121,8 +136,29 @@ func start(t *testing.T) (*mysql.Conn, func()) {

// cleanup restores the shared test state after a test: it runs the cluster
// panic handler, clears the twopc_user and twopc_t1 tables using the vtParams
// connection parameters, and resets the statement mapper sm.
func cleanup(t *testing.T) {
	cluster.PanicHandler(t)
	// The tables are cleared through the twopc-specific helper package
	// (imported as twopcutil); the bare `utils` name refers to the generic
	// endtoend utils package in this file.
	twopcutil.ClearOutTable(t, vtParams, "twopc_user")
	twopcutil.ClearOutTable(t, vtParams, "twopc_t1")
	sm.reset()
}

// startWithMySQL creates a MySQLCompare from vtParams and mysqlParams, failing
// the test immediately if that cannot be done. It empties the twopc_user table
// before returning the comparer together with a teardown function. The
// teardown empties the table again, closes the comparer and then runs the
// cluster panic handler.
func startWithMySQL(t *testing.T) (utils.MySQLCompare, func()) {
	mcmp, err := utils.NewMySQLCompare(t, vtParams, mysqlParams)
	require.NoError(t, err)

	// purge is best-effort: the result and error of each delete are discarded.
	purge := func() {
		for _, name := range []string{"twopc_user"} {
			_, _ = mcmp.ExecAndIgnore("delete from " + name)
		}
	}
	teardown := func() {
		purge()
		mcmp.Close()
		cluster.PanicHandler(t)
	}

	purge()
	return mcmp, teardown
}

type extractInterestingValues func(dtidMap map[string]string, vals []sqltypes.Value) []sqltypes.Value
Expand All @@ -147,7 +183,8 @@ var tables = map[string]extractInterestingValues{
},
"ks.redo_statement": func(dtidMap map[string]string, vals []sqltypes.Value) (out []sqltypes.Value) {
dtid := getDTID(dtidMap, vals[0].ToString())
out = append([]sqltypes.Value{sqltypes.NewVarChar(dtid)}, vals[1:]...)
stmt := getStatement(vals[2].ToString())
out = append([]sqltypes.Value{sqltypes.NewVarChar(dtid)}, vals[1], sqltypes.TestValue(sqltypes.Blob, stmt))
return
},
"ks.twopc_user": func(_ map[string]string, vals []sqltypes.Value) []sqltypes.Value { return vals },
Expand All @@ -167,6 +204,28 @@ func getDTID(dtidMap map[string]string, dtKey string) string {
return dtid
}

// getStatement normalizes savepoint-related statements so they can be compared
// independently of the actual savepoint name.
//
// A statement starting with "savepoint" becomes "savepoint-<id>" and one
// starting with "rollback to" becomes "rollback-<id>". The id is looked up in
// sm.stmt using the remainder of the statement after the matched prefix as the
// key; an unseen key is assigned the next sequential id (current map size plus
// one) and remembered. Any other statement is returned unchanged.
func getStatement(stmt string) string {
	rewrites := []struct {
		match string // leading text that identifies the statement kind
		label string // prefix of the normalized statement
	}{
		{match: "savepoint", label: "savepoint-"},
		{match: "rollback to", label: "rollback-"},
	}

	for _, rw := range rewrites {
		rest, ok := strings.CutPrefix(stmt, rw.match)
		if !ok {
			continue
		}
		id, seen := sm.stmt[rest]
		if !seen {
			id = fmt.Sprintf("%d", len(sm.stmt)+1)
			sm.stmt[rest] = id
		}
		return rw.label + id
	}
	return stmt
}

func runVStream(t *testing.T, ctx context.Context, ch chan *binlogdatapb.VEvent, vtgateConn *vtgateconn.VTGateConn) {
vgtid := &binlogdatapb.VGtid{
ShardGtids: []*binlogdatapb.ShardGtid{
Expand Down Expand Up @@ -272,3 +331,13 @@ func prettyPrint(v interface{}) string {
}
return string(b)
}

// stmtMapper remembers the identifier assigned to each savepoint statement
// key seen by getStatement.
type stmtMapper struct {
	// stmt maps a statement key to its assigned identifier.
	stmt map[string]string
}

// sm is the package-level mapper shared by the tests in this file.
var sm = &stmtMapper{stmt: map[string]string{}}

// reset discards every remembered mapping by installing a fresh, empty map.
func (m *stmtMapper) reset() {
	m.stmt = map[string]string{}
}
Loading

0 comments on commit a7b903b

Please sign in to comment.