Skip to content

Commit

Permalink
VReplication: handle escaped identifiers in vschema when initializing…
Browse files Browse the repository at this point in the history
… sequence tables (#16169)

Signed-off-by: Matt Lord <mattalord@gmail.com>
  • Loading branch information
mattlord authored Jun 18, 2024
1 parent 56aa1a6 commit ee01017
Show file tree
Hide file tree
Showing 10 changed files with 427 additions and 45 deletions.
6 changes: 3 additions & 3 deletions go/test/endtoend/vreplication/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ create table nopk (name varchar(128), age int unsigned);
],
"auto_increment": {
"column": "cid",
"sequence": "customer_seq"
"sequence": "` + "`customer_seq`" + `"
}
},
"customer_name": {
Expand Down Expand Up @@ -295,7 +295,7 @@ create table nopk (name varchar(128), age int unsigned);
],
"auto_increment": {
"column": "cid",
"sequence": "customer_seq"
"sequence": "` + "`product`.`customer_seq`" + `"
}
},
"orders": {
Expand Down Expand Up @@ -345,7 +345,7 @@ create table nopk (name varchar(128), age int unsigned);
],
"auto_increment": {
"column": "cid",
"sequence": "customer_seq"
"sequence": "` + "`customer_seq`" + `"
}
},
"orders": {
Expand Down
4 changes: 2 additions & 2 deletions go/test/endtoend/vreplication/fk_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ func doReshard(t *testing.T, keyspace, workflowName, sourceShards, targetShards
sourceShards: sourceShards,
targetShards: targetShards,
skipSchemaCopy: true,
}, workflowFlavorVtctl)
}, workflowFlavorVtctld)
rs.Create()
waitForWorkflowState(t, vc, fmt.Sprintf("%s.%s", keyspace, workflowName), binlogdatapb.VReplicationWorkflowState_Running.String())
for _, targetTab := range targetTabs {
Expand Down Expand Up @@ -355,7 +355,7 @@ func doMoveTables(t *testing.T, sourceKeyspace, targetKeyspace, workflowName, ta
},
sourceKeyspace: sourceKeyspace,
atomicCopy: atomicCopy,
}, workflowFlavorRandom)
}, workflowFlavorVtctld)
mt.Create()

waitForWorkflowState(t, vc, fmt.Sprintf("%s.%s", targetKeyspace, workflowName), binlogdatapb.VReplicationWorkflowState_Running.String())
Expand Down
2 changes: 1 addition & 1 deletion go/test/endtoend/vreplication/fk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
)

const testWorkflowFlavor = workflowFlavorRandom
const testWorkflowFlavor = workflowFlavorVtctld

// TestFKWorkflow runs a MoveTables workflow with atomic copy for a db with foreign key constraints.
// It inserts initial data, then simulates load. We insert both child rows with foreign keys and those without,
Expand Down
8 changes: 2 additions & 6 deletions go/test/endtoend/vreplication/partial_movetables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func testCancel(t *testing.T) {
sourceKeyspace: sourceKeyspace,
tables: table,
sourceShards: shard,
}, workflowFlavorRandom)
}, workflowFlavorVtctld)
mt.Create()

checkDenyList := func(keyspace string, expected bool) {
Expand Down Expand Up @@ -390,9 +390,5 @@ func testPartialMoveTablesBasic(t *testing.T, flavor workflowFlavor) {
// We test with both the vtctlclient and vtctldclient flavors.
func TestPartialMoveTablesBasic(t *testing.T) {
currentWorkflowType = binlogdatapb.VReplicationWorkflowType_MoveTables
for _, flavor := range workflowFlavors {
t.Run(workflowFlavorNames[flavor], func(t *testing.T) {
testPartialMoveTablesBasic(t, flavor)
})
}
testPartialMoveTablesBasic(t, workflowFlavorVtctld)
}
1 change: 1 addition & 0 deletions go/test/endtoend/vreplication/vdiff2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ func TestVDiff2(t *testing.T) {
// We ONLY add primary tablets in this test.
tks, err := vc.AddKeyspace(t, []*Cell{zone3, zone1, zone2}, targetKs, strings.Join(targetShards, ","), customerVSchema, customerSchema, 0, 0, 200, targetKsOpts)
require.NoError(t, err)
verifyClusterHealth(t, vc)

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
Expand Down
15 changes: 7 additions & 8 deletions go/test/endtoend/vreplication/vreplication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1563,17 +1563,16 @@ func switchWrites(t *testing.T, workflowType, ksWorkflow string, reverse bool) {
}
const SwitchWritesTimeout = "91s" // max: 3 tablet picker 30s waits + 1
ensureCanSwitch(t, workflowType, "", ksWorkflow)
// Use vtctldclient for MoveTables SwitchTraffic ~ 50% of the time.
if workflowType == binlogdatapb.VReplicationWorkflowType_MoveTables.String() && time.Now().Second()%2 == 0 {
parts := strings.Split(ksWorkflow, ".")
require.Equal(t, 2, len(parts))
moveTablesAction(t, command, defaultCellName, parts[1], sourceKs, parts[0], "", "--timeout="+SwitchWritesTimeout, "--tablet-types=primary")
targetKs, workflow, found := strings.Cut(ksWorkflow, ".")
require.True(t, found)
if workflowType == binlogdatapb.VReplicationWorkflowType_MoveTables.String() {
moveTablesAction(t, command, defaultCellName, workflow, sourceKs, targetKs, "", "--timeout="+SwitchWritesTimeout, "--tablet-types=primary")
return
}
output, err := vc.VtctlClient.ExecuteCommandWithOutput(workflowType, "--", "--tablet_types=primary",
"--timeout="+SwitchWritesTimeout, "--initialize-target-sequences", command, ksWorkflow)
output, err := vc.VtctldClient.ExecuteCommandWithOutput(workflowType, "--tablet-types=primary", "--workflow", workflow,
"--target-keyspace", targetKs, command, "--timeout="+SwitchWritesTimeout, "--initialize-target-sequences")
if output != "" {
fmt.Printf("Output of switching writes with vtctlclient for %s:\n++++++\n%s\n--------\n", ksWorkflow, output)
fmt.Printf("Output of switching writes with vtctldclient for %s:\n++++++\n%s\n--------\n", ksWorkflow, output)
}
// printSwitchWritesExtraDebug is useful when debugging failures in Switch writes due to corner cases/races
_ = printSwitchWritesExtraDebug
Expand Down
127 changes: 104 additions & 23 deletions go/vt/vtctl/workflow/traffic_switcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,12 @@ func (ts *trafficSwitcher) getTargetSequenceMetadata(ctx context.Context) (map[s
return nil
}
for tableName, tableDef := range kvs.Tables {
// The table name can be escaped in the vschema definition.
unescapedTableName, err := sqlescape.UnescapeID(tableName)
if err != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid table name %s in keyspace %s: %v",
tableName, keyspace, err)
}
select {
case <-sctx.Done():
return sctx.Err()
Expand All @@ -1396,9 +1402,9 @@ func (ts *trafficSwitcher) getTargetSequenceMetadata(ctx context.Context) (map[s
if complete := func() bool {
smMu.Lock() // Prevent concurrent access to the map
defer smMu.Unlock()
sm := sequencesByBackingTable[tableName]
sm := sequencesByBackingTable[unescapedTableName]
if tableDef != nil && tableDef.Type == vindexes.TypeSequence &&
sm != nil && tableName == sm.backingTableName {
sm != nil && unescapedTableName == sm.backingTableName {
tablesFound++ // This is also protected by the mutex
sm.backingTableKeyspace = keyspace
// Set the default keyspace name. We will later check to
Expand Down Expand Up @@ -1429,18 +1435,22 @@ func (ts *trafficSwitcher) getTargetSequenceMetadata(ctx context.Context) (map[s
searchGroup, gctx := errgroup.WithContext(ctx)
searchCompleted := make(chan struct{})
for _, keyspace := range keyspaces {
keyspace := keyspace // https://golang.org/doc/faq#closures_and_goroutines
// The keyspace name could be escaped so we need to unescape it.
ks, err := sqlescape.UnescapeID(keyspace)
if err != nil { // Should never happen
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid keyspace name %s: %v", keyspace, err)
}
searchGroup.Go(func() error {
return searchKeyspace(gctx, searchCompleted, keyspace)
return searchKeyspace(gctx, searchCompleted, ks)
})
}
if err := searchGroup.Wait(); err != nil {
return nil, err
}

if tablesFound != tableCount {
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to locate all of the backing sequence tables being used; sequence table metadata: %+v",
sequencesByBackingTable)
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to locate all of the backing sequence tables being used: %s",
strings.Join(maps.Keys(sequencesByBackingTable), ","))
}
return sequencesByBackingTable, nil
}
Expand All @@ -1460,34 +1470,73 @@ func (ts *trafficSwitcher) findSequenceUsageInKeyspace(vschema *vschemapb.Keyspa
targetDBName := targets[0].GetPrimary().DbName()
sequencesByBackingTable := make(map[string]*sequenceMetadata)

for _, table := range ts.Tables() {
vs, ok := vschema.Tables[table]
if !ok || vs.GetAutoIncrement().GetSequence() == "" {
for _, table := range ts.tables {
seqTable, ok := vschema.Tables[table]
if !ok || seqTable.GetAutoIncrement().GetSequence() == "" {
continue
}
// Be sure that the table name is unescaped as it can be escaped
// in the vschema.
unescapedTable, err := sqlescape.UnescapeID(table)
if err != nil {
return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid table name %s defined in the sequence table %+v: %v",
table, seqTable, err)
}
sm := &sequenceMetadata{
backingTableName: vs.AutoIncrement.Sequence,
usingTableName: table,
usingTableDefinition: vs,
usingTableDBName: targetDBName,
usingTableName: unescapedTable,
usingTableDBName: targetDBName,
}
// If the sequence table is fully qualified in the vschema then
// we don't need to find it later.
if strings.Contains(vs.AutoIncrement.Sequence, ".") {
keyspace, tableName, found := strings.Cut(vs.AutoIncrement.Sequence, ".")
if strings.Contains(seqTable.AutoIncrement.Sequence, ".") {
keyspace, tableName, found := strings.Cut(seqTable.AutoIncrement.Sequence, ".")
if !found {
return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid sequence table name %s defined in the %s keyspace",
vs.AutoIncrement.Sequence, ts.targetKeyspace)
seqTable.AutoIncrement.Sequence, ts.targetKeyspace)
}
// Unescape the table name and keyspace name as they may be escaped in the
// vschema definition if they e.g. contain dashes.
if keyspace, err = sqlescape.UnescapeID(keyspace); err != nil {
return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid keyspace in qualified sequence table name %s defined in sequence table %+v: %v",
seqTable.AutoIncrement.Sequence, seqTable, err)
}
if tableName, err = sqlescape.UnescapeID(tableName); err != nil {
return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid qualified sequence table name %s defined in sequence table %+v: %v",
seqTable.AutoIncrement.Sequence, seqTable, err)
}
sm.backingTableName = tableName
sm.backingTableKeyspace = keyspace
sm.backingTableName = tableName
// Update the definition with the unescaped values.
seqTable.AutoIncrement.Sequence = fmt.Sprintf("%s.%s", keyspace, tableName)
// Set the default keyspace name. We will later check to
// see if the tablet we send requests to is using a dbname
// override and use that if it is.
sm.backingTableDBName = "vt_" + keyspace
} else {
sm.backingTableName, err = sqlescape.UnescapeID(seqTable.AutoIncrement.Sequence)
if err != nil {
return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid sequence table name %s defined in sequence table %+v: %v",
seqTable.AutoIncrement.Sequence, seqTable, err)
}
seqTable.AutoIncrement.Sequence = sm.backingTableName
allFullyQualified = false
}
// The column names can be escaped in the vschema definition.
for i := range seqTable.ColumnVindexes {
unescapedColumn, err := sqlescape.UnescapeID(seqTable.ColumnVindexes[i].Column)
if err != nil {
return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid sequence column vindex name %s defined in sequence table %+v: %v",
seqTable.ColumnVindexes[i].Column, seqTable, err)
}
seqTable.ColumnVindexes[i].Column = unescapedColumn
}
unescapedAutoIncCol, err := sqlescape.UnescapeID(seqTable.AutoIncrement.Column)
if err != nil {
return nil, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid auto-increment column name %s defined in sequence table %+v: %v",
seqTable.AutoIncrement.Column, seqTable, err)
}
seqTable.AutoIncrement.Column = unescapedAutoIncCol
sm.usingTableDefinition = seqTable
sequencesByBackingTable[sm.backingTableName] = sm
}

Expand Down Expand Up @@ -1516,10 +1565,25 @@ func (ts *trafficSwitcher) initializeTargetSequences(ctx context.Context, sequen
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "no primary tablet found for target shard %s/%s",
ts.targetKeyspace, target.GetShard().ShardName())
}
usingCol, err := sqlescape.EnsureEscaped(sequenceMetadata.usingTableDefinition.AutoIncrement.Column)
if err != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid column name %s specified for sequence in table %s: %v",
sequenceMetadata.usingTableDefinition.AutoIncrement.Column, sequenceMetadata.usingTableName, err)
}
usingDB, err := sqlescape.EnsureEscaped(sequenceMetadata.usingTableDBName)
if err != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid database name %s specified for sequence in table %s: %v",
sequenceMetadata.usingTableDBName, sequenceMetadata.usingTableName, err)
}
usingTable, err := sqlescape.EnsureEscaped(sequenceMetadata.usingTableName)
if err != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid sequence table name specified for sequence in table %s: %v",
sequenceMetadata.usingTableName, err)
}
query := sqlparser.BuildParsedQuery(sqlGetMaxSequenceVal,
sqlescape.EscapeID(sequenceMetadata.usingTableDefinition.AutoIncrement.Column),
sqlescape.EscapeID(sequenceMetadata.usingTableDBName),
sqlescape.EscapeID(sequenceMetadata.usingTableName),
usingCol,
usingDB,
usingTable,
)
qr, terr := ts.ws.tmc.ExecuteFetchAsApp(ictx, primary.Tablet, true, &tabletmanagerdatapb.ExecuteFetchAsAppRequest{
Query: []byte(query.Query),
Expand Down Expand Up @@ -1580,9 +1644,19 @@ func (ts *trafficSwitcher) initializeTargetSequences(ctx context.Context, sequen
if sequenceTablet.DbNameOverride != "" {
sequenceMetadata.backingTableDBName = sequenceTablet.DbNameOverride
}
backingDB, err := sqlescape.EnsureEscaped(sequenceMetadata.backingTableDBName)
if err != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid database name %s in sequence backing table %s: %v",
sequenceMetadata.backingTableDBName, sequenceMetadata.backingTableName, err)
}
backingTable, err := sqlescape.EnsureEscaped(sequenceMetadata.backingTableName)
if err != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid sequence backing table name %s: %v",
sequenceMetadata.backingTableName, err)
}
query := sqlparser.BuildParsedQuery(sqlInitSequenceTable,
sqlescape.EscapeID(sequenceMetadata.backingTableDBName),
sqlescape.EscapeID(sequenceMetadata.backingTableName),
backingDB,
backingTable,
nextVal,
nextVal,
nextVal,
Expand Down Expand Up @@ -1615,7 +1689,14 @@ func (ts *trafficSwitcher) initializeTargetSequences(ctx context.Context, sequen
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to get primary tablet for keyspace %s: %v",
sequenceMetadata.backingTableKeyspace, ierr)
}
ierr = ts.TabletManagerClient().ResetSequences(ictx, ti.Tablet, []string{sequenceMetadata.backingTableName})
// ResetSequences interfaces with the schema engine and the actual
// table identifiers DO NOT contain the backticks. So we have to
// ensure that the table name is unescaped.
unescapedBackingTable, err := sqlescape.UnescapeID(backingTable)
if err != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid sequence backing table name %s: %v", backingTable, err)
}
ierr = ts.TabletManagerClient().ResetSequences(ictx, ti.Tablet, []string{unescapedBackingTable})
if ierr != nil {
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to reset the sequence cache for backing table %s on shard %s/%s using tablet %s: %v",
sequenceMetadata.backingTableName, sequenceShard.Keyspace(), sequenceShard.ShardName(), sequenceShard.PrimaryAlias, ierr)
Expand Down
Loading

0 comments on commit ee01017

Please sign in to comment.