diff --git a/go/vt/vtctl/workflow/framework_test.go b/go/vt/vtctl/workflow/framework_test.go index 16bacc5f266..1d25aafa75f 100644 --- a/go/vt/vtctl/workflow/framework_test.go +++ b/go/vt/vtctl/workflow/framework_test.go @@ -254,6 +254,7 @@ type testTMClient struct { vrQueries map[int][]*queryResult createVReplicationWorkflowRequests map[uint32]*tabletmanagerdatapb.CreateVReplicationWorkflowRequest readVReplicationWorkflowRequests map[uint32]*tabletmanagerdatapb.ReadVReplicationWorkflowRequest + primaryPositions map[uint32]string env *testEnv // For access to the env config from tmc methods. reverse atomic.Bool // Are we reversing traffic? @@ -266,6 +267,7 @@ func newTestTMClient(env *testEnv) *testTMClient { vrQueries: make(map[int][]*queryResult), createVReplicationWorkflowRequests: make(map[uint32]*tabletmanagerdatapb.CreateVReplicationWorkflowRequest), readVReplicationWorkflowRequests: make(map[uint32]*tabletmanagerdatapb.ReadVReplicationWorkflowRequest), + primaryPositions: make(map[uint32]string), env: env, } } @@ -513,7 +515,21 @@ func (tmc *testTMClient) UpdateVReplicationWorkflow(ctx context.Context, tablet }, nil } +func (tmc *testTMClient) setPrimaryPosition(tablet *topodatapb.Tablet, position string) { + tmc.mu.Lock() + defer tmc.mu.Unlock() + if tmc.primaryPositions == nil { + tmc.primaryPositions = make(map[uint32]string) + } + tmc.primaryPositions[tablet.Alias.Uid] = position +} + func (tmc *testTMClient) PrimaryPosition(ctx context.Context, tablet *topodatapb.Tablet) (string, error) { + tmc.mu.Lock() + defer tmc.mu.Unlock() + if tmc.primaryPositions != nil && tmc.primaryPositions[tablet.Alias.Uid] != "" { + return tmc.primaryPositions[tablet.Alias.Uid], nil + } return position, nil } diff --git a/go/vt/vtctl/workflow/traffic_switcher_test.go b/go/vt/vtctl/workflow/traffic_switcher_test.go index 5c0b2aba682..dfe394b2608 100644 --- a/go/vt/vtctl/workflow/traffic_switcher_test.go +++ b/go/vt/vtctl/workflow/traffic_switcher_test.go @@ -20,12 +20,15 @@ import ( "context" "fmt" "reflect" + "strconv" + "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" "vitess.io/vitess/go/vt/proto/vschema" "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/vtgate/vindexes" @@ -361,3 +364,72 @@ func TestGetTargetSequenceMetadata(t *testing.T) { }) } } + +// TestSwitchTrafficPositionHandling confirms that if any writes are somehow +// executed against the source between the stop source writes and wait for +// catchup steps, that we have the correct position and do not lose the write(s). +func TestTrafficSwitchPositionHandling(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + workflowName := "wf1" + tableName := "t1" + sourceKeyspaceName := "sourceks" + targetKeyspaceName := "targetks" + + schema := map[string]*tabletmanagerdatapb.SchemaDefinition{ + tableName: { + TableDefinitions: []*tabletmanagerdatapb.TableDefinition{ + { + Name: tableName, + Schema: fmt.Sprintf("CREATE TABLE %s (id BIGINT, name VARCHAR(64), PRIMARY KEY (id))", tableName), + }, + }, + }, + } + + sourceKeyspace := &testKeyspace{ + KeyspaceName: sourceKeyspaceName, + ShardNames: []string{"0"}, + } + targetKeyspace := &testKeyspace{ + KeyspaceName: targetKeyspaceName, + ShardNames: []string{"0"}, + } + + env := newTestEnv(t, ctx, defaultCellName, sourceKeyspace, targetKeyspace) + defer env.close() + env.tmc.schema = schema + + ts, _, err := env.ws.getWorkflowState(ctx, targetKeyspaceName, workflowName) + require.NoError(t, err) + sw := &switcher{ts: ts, s: env.ws} + + lockCtx, sourceUnlock, lockErr := sw.lockKeyspace(ctx, ts.SourceKeyspaceName(), "test") + require.NoError(t, lockErr) + ctx = lockCtx + defer sourceUnlock(&err) + lockCtx, targetUnlock, lockErr := sw.lockKeyspace(ctx, ts.TargetKeyspaceName(), "test") + require.NoError(t, lockErr) + ctx = lockCtx + defer targetUnlock(&err) + + err = ts.stopSourceWrites(ctx) + require.NoError(t, err) + + // Now we simulate a write on the source. + newPosition := position[:strings.LastIndex(position, "-")+1] + oldSeqNo, err := strconv.Atoi(position[strings.LastIndex(position, "-")+1:]) + require.NoError(t, err) + newPosition = fmt.Sprintf("%s%d", newPosition, oldSeqNo+1) + env.tmc.setPrimaryPosition(env.tablets[sourceKeyspaceName][startingSourceTabletUID], newPosition) + + // And confirm that we picked up the new position. + err = ts.gatherSourcePositions(ctx) + require.NoError(t, err) + err = ts.ForAllSources(func(ms *MigrationSource) error { + require.Equal(t, newPosition, ms.Position) + return nil + }) + require.NoError(t, err) +}