Skip to content

Commit

Permalink
Add more tests
Browse files Browse the repository at this point in the history
Signed-off-by: MyonKeminta <MyonKeminta@users.noreply.github.com>
  • Loading branch information
MyonKeminta committed Nov 6, 2024
1 parent 59d5568 commit 19c6546
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 2 deletions.
2 changes: 1 addition & 1 deletion oracle/oracles/pd.go
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Op

func (o *pdOracle) ValidateSnapshotReadTS(ctx context.Context, readTS uint64, opt *oracle.Option) error {
latestTS, err := o.GetLowResolutionTimestamp(ctx, opt)
// If we fail to get latestTS or the readTS exceeds it, get a timestamp from PD to double check.
// If we fail to get latestTS or the readTS exceeds it, get a timestamp from PD to double-check.
// But we don't need to strictly fetch the latest TS. So if there are already concurrent calls to this function
// loading the latest TS, we can just reuse the same result to avoid too many concurrent GetTS calls.
if err != nil || readTS > latestTS {
Expand Down
159 changes: 158 additions & 1 deletion oracle/oracles/pd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,15 @@ func TestNonFutureStaleTSO(t *testing.T) {
}
}

func TestNextUpdateTSInterval(t *testing.T) {
func TestAdaptiveUpdateTSInterval(t *testing.T) {
oracleInterface, err := NewPdOracle(&MockPdClient{}, &PDOracleOptions{
UpdateInterval: time.Second * 2,
NoUpdateTS: true,
})
assert.NoError(t, err)
o := oracleInterface.(*pdOracle)
defer o.Close()

now := time.Now()

mockTS := func(beforeNow time.Duration) uint64 {
Expand Down Expand Up @@ -339,3 +341,158 @@ func TestNextUpdateTSInterval(t *testing.T) {
assert.Equal(t, minAllowedAdaptiveUpdateTSInterval/2, o.nextUpdateInterval(now, 0))
assert.Equal(t, adaptiveUpdateTSIntervalStateUnadjustable, o.adaptiveUpdateIntervalState.state)
}

func TestValidateSnapshotReadTS(t *testing.T) {
pdClient := MockPdClient{}
o, err := NewPdOracle(&pdClient, &PDOracleOptions{
UpdateInterval: time.Second * 2,
})
assert.NoError(t, err)
defer o.Close()

ctx := context.Background()
opt := &oracle.Option{TxnScope: oracle.GlobalTxnScope}
ts, err := o.GetTimestamp(ctx, opt)
assert.NoError(t, err)
assert.GreaterOrEqual(t, ts, uint64(1))

err = o.ValidateSnapshotReadTS(ctx, 1, opt)
assert.NoError(t, err)
ts, err = o.GetTimestamp(ctx, opt)
assert.NoError(t, err)
// The readTS exceeds the latest ts, so it first fails the check with the low resolution ts. Then it fallbacks to
// the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass.
err = o.ValidateSnapshotReadTS(ctx, ts+1, opt)
assert.NoError(t, err)
// It can't pass if the readTS is newer than previous ts + 2.
ts, err = o.GetTimestamp(ctx, opt)
assert.NoError(t, err)
err = o.ValidateSnapshotReadTS(ctx, ts+2, opt)
assert.Error(t, err)

// Simulate other PD clients requests a timestamp.
ts, err = o.GetTimestamp(ctx, opt)
assert.NoError(t, err)
pdClient.logicalTimestamp.Add(2)
err = o.ValidateSnapshotReadTS(ctx, ts+3, opt)
assert.NoError(t, err)
}

type MockPDClientWithPause struct {
MockPdClient
mu sync.Mutex
}

func (c *MockPDClientWithPause) GetTS(ctx context.Context) (int64, int64, error) {
c.mu.Lock()
defer c.mu.Unlock()
return c.MockPdClient.GetTS(ctx)
}

func (c *MockPDClientWithPause) Pause() {
c.mu.Lock()
}

func (c *MockPDClientWithPause) Resume() {
c.mu.Unlock()
}

func TestValidateSnapshotReadTSReusingGetTSResult(t *testing.T) {
pdClient := &MockPDClientWithPause{}
o, err := NewPdOracle(pdClient, &PDOracleOptions{
UpdateInterval: time.Second * 2,
NoUpdateTS: true,
})
assert.NoError(t, err)
defer o.Close()

asyncValidate := func(ctx context.Context, readTS uint64) chan error {
ch := make(chan error, 1)
go func() {
err := o.ValidateSnapshotReadTS(ctx, readTS, &oracle.Option{TxnScope: oracle.GlobalTxnScope})
ch <- err
}()
return ch
}

noResult := func(ch chan error) {
select {
case <-ch:
assert.FailNow(t, "a ValidateSnapshotReadTS operation is not blocked while it's expected to be blocked")
default:
}
}

cancelIndices := []int{-1, -1, 0, 1}
for i, ts := range []uint64{100, 200, 300, 400} {
// Note: the ts is the result that the next GetTS will return. Any validation with readTS <= ts should pass, otherwise fail.

// We will cancel the cancelIndex-th validation call. This is for testing that canceling some of the calls
// doesn't affect other calls that are waiting
cancelIndex := cancelIndices[i]

pdClient.Pause()

results := make([]chan error, 0, 5)

ctx, cancel := context.WithCancel(context.Background())

getCtx := func(index int) context.Context {
if cancelIndex == index {
return ctx
} else {
return context.Background()
}
}

results = append(results, asyncValidate(getCtx(0), ts-2))
results = append(results, asyncValidate(getCtx(1), ts+2))
results = append(results, asyncValidate(getCtx(2), ts-1))
results = append(results, asyncValidate(getCtx(3), ts+1))
results = append(results, asyncValidate(getCtx(4), ts))

expectedSucceeds := []bool{true, false, true, false, true}

time.Sleep(time.Millisecond * 50)
for _, ch := range results {
noResult(ch)
}

cancel()

for i, ch := range results {
if i == cancelIndex {
select {
case err := <-ch:
assert.Errorf(t, err, "index: %v", i)
assert.Containsf(t, err.Error(), "context canceled", "index: %v", i)
case <-time.After(time.Second):
assert.FailNowf(t, "expected result to be ready but still blocked", "index: %v", i)
}
} else {
noResult(ch)
}
}

// ts will be the next ts returned to these validation calls.
pdClient.logicalTimestamp.Store(int64(ts - 1))
pdClient.Resume()
for i, ch := range results {
if i == cancelIndex {
continue
}

select {
case err = <-ch:
case <-time.After(time.Second):
assert.FailNowf(t, "expected result to be ready but still blocked", "index: %v", i)
}
if expectedSucceeds[i] {
assert.NoErrorf(t, err, "index: %v", i)
} else {
assert.Errorf(t, err, "index: %v", i)
assert.NotContainsf(t, err.Error(), "context canceled", "index: %v", i)
}
}
}
}

0 comments on commit 19c6546

Please sign in to comment.