Skip to content

Commit

Permalink
Add timeout to all the contexts used for RPC calls in vtorc (#15991)
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <manan@planetscale.com>
  • Loading branch information
GuptaManan100 authored May 22, 2024
1 parent 2283f6b commit a1edaee
Show file tree
Hide file tree
Showing 3 changed files with 289 additions and 6 deletions.
16 changes: 15 additions & 1 deletion go/vt/vtctl/grpcvtctldserver/testutil/test_tmclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ type TabletManagerClient struct {
}
// keyed by tablet alias.
ChangeTabletTypeResult map[string]error
ChangeTabletTypeDelays map[string]time.Duration
// keyed by tablet alias.
DemotePrimaryDelays map[string]time.Duration
// keyed by tablet alias.
Expand Down Expand Up @@ -468,7 +469,20 @@ func (fake *TabletManagerClient) Backup(ctx context.Context, tablet *topodatapb.

// ChangeType is part of the tmclient.TabletManagerClient interface.
func (fake *TabletManagerClient) ChangeType(ctx context.Context, tablet *topodatapb.Tablet, newType topodatapb.TabletType, semiSync bool) error {
if result, ok := fake.ChangeTabletTypeResult[topoproto.TabletAliasString(tablet.Alias)]; ok {
key := topoproto.TabletAliasString(tablet.Alias)

if fake.ChangeTabletTypeDelays != nil {
if delay, ok := fake.ChangeTabletTypeDelays[key]; ok {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delay):
// proceed to results
}
}
}

if result, ok := fake.ChangeTabletTypeResult[key]; ok {
return result
}

Expand Down
20 changes: 15 additions & 5 deletions go/vt/vtorc/logic/tablet_discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,27 +285,37 @@ func LockShard(ctx context.Context, tabletAlias string, lockAction string) (cont

// tabletUndoDemotePrimary calls the said RPC for the given tablet.
func tabletUndoDemotePrimary(ctx context.Context, tablet *topodatapb.Tablet, semiSync bool) error {
return tmc.UndoDemotePrimary(ctx, tablet, semiSync)
tmcCtx, tmcCancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout)
defer tmcCancel()
return tmc.UndoDemotePrimary(tmcCtx, tablet, semiSync)
}

// setReadOnly calls the said RPC for the given tablet
func setReadOnly(ctx context.Context, tablet *topodatapb.Tablet) error {
return tmc.SetReadOnly(ctx, tablet)
tmcCtx, tmcCancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout)
defer tmcCancel()
return tmc.SetReadOnly(tmcCtx, tablet)
}

// changeTabletType calls the said RPC for the given tablet with the given parameters.
func changeTabletType(ctx context.Context, tablet *topodatapb.Tablet, tabletType topodatapb.TabletType, semiSync bool) error {
return tmc.ChangeType(ctx, tablet, tabletType, semiSync)
tmcCtx, tmcCancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout)
defer tmcCancel()
return tmc.ChangeType(tmcCtx, tablet, tabletType, semiSync)
}

// resetReplicationParameters resets the replication parameters on the given tablet.
func resetReplicationParameters(ctx context.Context, tablet *topodatapb.Tablet) error {
return tmc.ResetReplicationParameters(ctx, tablet)
tmcCtx, tmcCancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout)
defer tmcCancel()
return tmc.ResetReplicationParameters(tmcCtx, tablet)
}

// setReplicationSource calls the said RPC with the parameters provided
func setReplicationSource(ctx context.Context, replica *topodatapb.Tablet, primary *topodatapb.Tablet, semiSync bool, heartbeatInterval float64) error {
return tmc.SetReplicationSource(ctx, replica, primary.Alias, 0, "", true, semiSync, heartbeatInterval)
tmcCtx, tmcCancel := context.WithTimeout(ctx, topo.RemoteOperationTimeout)
defer tmcCancel()
return tmc.SetReplicationSource(tmcCtx, replica, primary.Alias, 0, "", true, semiSync, heartbeatInterval)
}

// shardPrimary finds the primary of the given keyspace-shard by reading the vtorc backend
Expand Down
259 changes: 259 additions & 0 deletions go/vt/vtorc/logic/tablet_discovery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"sync/atomic"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
Expand All @@ -30,8 +31,10 @@ import (
"vitess.io/vitess/go/vt/external/golib/sqlutils"
topodatapb "vitess.io/vitess/go/vt/proto/topodata"
"vitess.io/vitess/go/vt/proto/vttime"
"vitess.io/vitess/go/vt/topo"
"vitess.io/vitess/go/vt/topo/memorytopo"
"vitess.io/vitess/go/vt/topo/topoproto"
"vitess.io/vitess/go/vt/vtctl/grpcvtctldserver/testutil"
"vitess.io/vitess/go/vt/vtorc/db"
"vitess.io/vitess/go/vt/vtorc/inst"
"vitess.io/vitess/go/vt/vtorc/process"
Expand Down Expand Up @@ -362,3 +365,259 @@ func TestProcessHealth(t *testing.T) {
_, discoveredOnce = process.HealthTest()
require.True(t, discoveredOnce)
}

func TestSetReadOnly(t *testing.T) {
tests := []struct {
name string
tablet *topodatapb.Tablet
tmc *testutil.TabletManagerClient
remoteOpTimeout time.Duration
errShouldContain string
}{
{
name: "Success",
tablet: tab100,
tmc: &testutil.TabletManagerClient{
SetReadOnlyResults: map[string]error{
"zone-1-0000000100": nil,
},
},
}, {
name: "Failure",
tablet: tab100,
tmc: &testutil.TabletManagerClient{
SetReadOnlyResults: map[string]error{
"zone-1-0000000100": fmt.Errorf("testing error"),
},
},
errShouldContain: "testing error",
}, {
name: "Timeout",
tablet: tab100,
remoteOpTimeout: 100 * time.Millisecond,
tmc: &testutil.TabletManagerClient{
SetReadOnlyResults: map[string]error{
"zone-1-0000000100": nil,
},
SetReadOnlyDelays: map[string]time.Duration{
"zone-1-0000000100": 200 * time.Millisecond,
},
},
errShouldContain: "context deadline exceeded",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
oldTmc := tmc
oldRemoteOpTimeout := topo.RemoteOperationTimeout
defer func() {
tmc = oldTmc
topo.RemoteOperationTimeout = oldRemoteOpTimeout
}()

tmc = tt.tmc
if tt.remoteOpTimeout != 0 {
topo.RemoteOperationTimeout = tt.remoteOpTimeout
}

err := setReadOnly(context.Background(), tt.tablet)
if tt.errShouldContain == "" {
require.NoError(t, err)
return
}
require.ErrorContains(t, err, tt.errShouldContain)
})
}
}

func TestTabletUndoDemotePrimary(t *testing.T) {
tests := []struct {
name string
tablet *topodatapb.Tablet
tmc *testutil.TabletManagerClient
remoteOpTimeout time.Duration
errShouldContain string
}{
{
name: "Success",
tablet: tab100,
tmc: &testutil.TabletManagerClient{
UndoDemotePrimaryResults: map[string]error{
"zone-1-0000000100": nil,
},
},
}, {
name: "Failure",
tablet: tab100,
tmc: &testutil.TabletManagerClient{
UndoDemotePrimaryResults: map[string]error{
"zone-1-0000000100": fmt.Errorf("testing error"),
},
},
errShouldContain: "testing error",
}, {
name: "Timeout",
tablet: tab100,
remoteOpTimeout: 100 * time.Millisecond,
tmc: &testutil.TabletManagerClient{
UndoDemotePrimaryResults: map[string]error{
"zone-1-0000000100": nil,
},
UndoDemotePrimaryDelays: map[string]time.Duration{
"zone-1-0000000100": 200 * time.Millisecond,
},
},
errShouldContain: "context deadline exceeded",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
oldTmc := tmc
oldRemoteOpTimeout := topo.RemoteOperationTimeout
defer func() {
tmc = oldTmc
topo.RemoteOperationTimeout = oldRemoteOpTimeout
}()

tmc = tt.tmc
if tt.remoteOpTimeout != 0 {
topo.RemoteOperationTimeout = tt.remoteOpTimeout
}

err := tabletUndoDemotePrimary(context.Background(), tt.tablet, false)
if tt.errShouldContain == "" {
require.NoError(t, err)
return
}
require.ErrorContains(t, err, tt.errShouldContain)
})
}
}

func TestChangeTabletType(t *testing.T) {
tests := []struct {
name string
tablet *topodatapb.Tablet
tmc *testutil.TabletManagerClient
remoteOpTimeout time.Duration
errShouldContain string
}{
{
name: "Success",
tablet: tab100,
tmc: &testutil.TabletManagerClient{
ChangeTabletTypeResult: map[string]error{
"zone-1-0000000100": nil,
},
},
}, {
name: "Failure",
tablet: tab100,
tmc: &testutil.TabletManagerClient{
ChangeTabletTypeResult: map[string]error{
"zone-1-0000000100": fmt.Errorf("testing error"),
},
},
errShouldContain: "testing error",
}, {
name: "Timeout",
tablet: tab100,
remoteOpTimeout: 100 * time.Millisecond,
tmc: &testutil.TabletManagerClient{
ChangeTabletTypeResult: map[string]error{
"zone-1-0000000100": nil,
},
ChangeTabletTypeDelays: map[string]time.Duration{
"zone-1-0000000100": 200 * time.Millisecond,
},
},
errShouldContain: "context deadline exceeded",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
oldTmc := tmc
oldRemoteOpTimeout := topo.RemoteOperationTimeout
defer func() {
tmc = oldTmc
topo.RemoteOperationTimeout = oldRemoteOpTimeout
}()

tmc = tt.tmc
if tt.remoteOpTimeout != 0 {
topo.RemoteOperationTimeout = tt.remoteOpTimeout
}

err := changeTabletType(context.Background(), tt.tablet, topodatapb.TabletType_REPLICA, false)
if tt.errShouldContain == "" {
require.NoError(t, err)
return
}
require.ErrorContains(t, err, tt.errShouldContain)
})
}
}

func TestSetReplicationSource(t *testing.T) {
tests := []struct {
name string
tablet *topodatapb.Tablet
tmc *testutil.TabletManagerClient
remoteOpTimeout time.Duration
errShouldContain string
}{
{
name: "Success",
tablet: tab100,
tmc: &testutil.TabletManagerClient{
SetReplicationSourceResults: map[string]error{
"zone-1-0000000100": nil,
},
},
}, {
name: "Failure",
tablet: tab100,
tmc: &testutil.TabletManagerClient{
SetReplicationSourceResults: map[string]error{
"zone-1-0000000100": fmt.Errorf("testing error"),
},
},
errShouldContain: "testing error",
}, {
name: "Timeout",
tablet: tab100,
remoteOpTimeout: 100 * time.Millisecond,
tmc: &testutil.TabletManagerClient{
SetReplicationSourceResults: map[string]error{
"zone-1-0000000100": nil,
},
SetReplicationSourceDelays: map[string]time.Duration{
"zone-1-0000000100": 200 * time.Millisecond,
},
},
errShouldContain: "context deadline exceeded",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
oldTmc := tmc
oldRemoteOpTimeout := topo.RemoteOperationTimeout
defer func() {
tmc = oldTmc
topo.RemoteOperationTimeout = oldRemoteOpTimeout
}()

tmc = tt.tmc
if tt.remoteOpTimeout != 0 {
topo.RemoteOperationTimeout = tt.remoteOpTimeout
}

err := setReplicationSource(context.Background(), tt.tablet, tab101, false, 0)
if tt.errShouldContain == "" {
require.NoError(t, err)
return
}
require.ErrorContains(t, err, tt.errShouldContain)
})
}
}

0 comments on commit a1edaee

Please sign in to comment.