From cc6b5a64653e8a7f9cdcb164afa8447dc06306ea Mon Sep 17 00:00:00 2001 From: Noble Mittal <62551163+beingnoble03@users.noreply.github.com> Date: Mon, 15 Apr 2024 20:54:32 +0530 Subject: [PATCH] test: Use testify require/assert instead of t.Fatal/Error in `go/vt/throttler` (#15703) Signed-off-by: Noble Mittal --- .../aggregated_interval_history_test.go | 7 +- go/vt/throttler/interval_history_test.go | 35 ++- go/vt/throttler/manager_test.go | 153 ++++++------- go/vt/throttler/memory_test.go | 209 +++++++++--------- go/vt/throttler/replication_lag_cache_test.go | 53 ++--- go/vt/throttler/result_test.go | 22 +- go/vt/throttler/thread_throttler_test.go | 13 +- go/vt/throttler/throttler_test.go | 126 ++++------- go/vt/throttler/throttlerlogz_test.go | 17 +- go/vt/throttler/throttlerz_test.go | 22 +- 10 files changed, 282 insertions(+), 375 deletions(-) diff --git a/go/vt/throttler/aggregated_interval_history_test.go b/go/vt/throttler/aggregated_interval_history_test.go index 6a77d57af07..f9348c10920 100644 --- a/go/vt/throttler/aggregated_interval_history_test.go +++ b/go/vt/throttler/aggregated_interval_history_test.go @@ -19,6 +19,8 @@ package throttler import ( "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestAggregatedIntervalHistory(t *testing.T) { @@ -26,7 +28,6 @@ func TestAggregatedIntervalHistory(t *testing.T) { h.addPerThread(0, record{sinceZero(0 * time.Second), 1000}) h.addPerThread(1, record{sinceZero(0 * time.Second), 2000}) - if got, want := h.average(sinceZero(250*time.Millisecond), sinceZero(750*time.Millisecond)), 3000.0; got != want { - t.Errorf("average(0.25s, 0.75s) across both threads = %v, want = %v", got, want) - } + got := h.average(sinceZero(250*time.Millisecond), sinceZero(750*time.Millisecond)) + assert.Equal(t, 3000.0, got) } diff --git a/go/vt/throttler/interval_history_test.go b/go/vt/throttler/interval_history_test.go index 7bad56e41c1..ec30b1c23c9 100644 --- a/go/vt/throttler/interval_history_test.go +++ b/go/vt/throttler/interval_history_test.go @@ -17,9 +17,11 @@ limitations under the License. package throttler import ( - "strings" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestIntervalHistory_AverageIncludesPartialIntervals(t *testing.T) { @@ -33,9 +35,8 @@ func TestIntervalHistory_AverageIncludesPartialIntervals(t *testing.T) { h.add(record{sinceZero(3 * time.Second), 10000000}) // Rate within [1s, 2s) = 1000 and within [2s, 3s) = 2000 = average of 1500 want := 1500.0 - if got := h.average(sinceZero(1500*time.Millisecond), sinceZero(2500*time.Millisecond)); got != want { - t.Errorf("average(1.5s, 2.5s) = %v, want = %v", got, want) - } + got := h.average(sinceZero(1500*time.Millisecond), sinceZero(2500*time.Millisecond)) + assert.Equal(t, want, got) } func TestIntervalHistory_AverageRangeSmallerThanInterval(t *testing.T) { @@ -43,9 +44,8 @@ func TestIntervalHistory_AverageRangeSmallerThanInterval(t *testing.T) { h.add(record{sinceZero(0 * time.Second), 10000}) want := 10000.0 - if got := h.average(sinceZero(250*time.Millisecond), sinceZero(750*time.Millisecond)); got != want { - t.Errorf("average(0.25s, 0.75s) = %v, want = %v", got, want) - } + got := h.average(sinceZero(250*time.Millisecond), sinceZero(750*time.Millisecond)) + assert.Equal(t, want, got) } func TestIntervalHistory_GapsCountedAsZero(t *testing.T) { @@ -55,22 +55,17 @@ func TestIntervalHistory_GapsCountedAsZero(t *testing.T) { h.add(record{sinceZero(3 * time.Second), 1000}) want := 500.0 - if got := h.average(sinceZero(0*time.Second), sinceZero(4*time.Second)); got != want { - t.Errorf("average(0s, 4s) = %v, want = %v", got, want) - } + got := h.average(sinceZero(0*time.Second), sinceZero(4*time.Second)) + assert.Equal(t, want, got) } func TestIntervalHistory_AddNoDuplicateInterval(t *testing.T) { defer func() { r := recover() + require.NotNil(t, r, "add() did not panic") - if r == nil { - t.Fatal("add() did not panic") - } want := "BUG: cannot add record because it is already covered by a previous entry" - if !strings.Contains(r.(string), want) { - t.Fatalf("add() did panic for the wrong reason: got = %v, want = %v", r, want) - } + require.Contains(t, r, want, "add() did panic for the wrong reason") }() h := newIntervalHistory(10, 1*time.Second) @@ -82,14 +77,10 @@ func TestIntervalHistory_AddNoDuplicateInterval(t *testing.T) { func TestIntervalHistory_RecordDoesNotStartAtInterval(t *testing.T) { defer func() { r := recover() + require.NotNil(t, r, "add() did not panic") - if r == nil { - t.Fatal("add() did not panic") - } want := "BUG: cannot add record because it does not start at the beginning of the interval" - if !strings.Contains(r.(string), want) { - t.Fatalf("add() did panic for the wrong reason: got = %v, want = %v", r, want) - } + require.Contains(t, r, want, "add() did panic for the wrong reason") }() h := newIntervalHistory(1, 1*time.Second) diff --git a/go/vt/throttler/manager_test.go b/go/vt/throttler/manager_test.go index e6c3359b242..3d61d4d6b68 100644 --- a/go/vt/throttler/manager_test.go +++ b/go/vt/throttler/manager_test.go @@ -20,10 +20,12 @@ import ( "fmt" "reflect" "sort" - "strings" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + throttlerdatapb "vitess.io/vitess/go/vt/proto/throttlerdata" ) @@ -60,12 +62,11 @@ func (f *managerTestFixture) tearDown() { func TestManager_Registration(t *testing.T) { m := newManager() t1, err := newThrottler(m, "t1", "TPS", 1 /* threadCount */, MaxRateModuleDisabled, ReplicationLagModuleDisabled, time.Now) - if err != nil { - t.Fatal(err) - } - if err := m.registerThrottler("t1", t1); err == nil { - t.Fatalf("manager should not accept a duplicate registration of a throttler: %v", err) - } + require.NoError(t, err) + + err = m.registerThrottler("t1", t1) + require.Error(t, err, "manager should not accept a duplicate registration of a throttler") + t1.Close() // Unregistering an unregistered throttler should log an error. @@ -81,18 +82,16 @@ func TestManager_SetMaxRate(t *testing.T) { // Test SetMaxRate(). want := []string{"t1", "t2"} - if got := f.m.SetMaxRate(23); !reflect.DeepEqual(got, want) { - t.Errorf("manager did not set the rate on all throttlers. got = %v, want = %v", got, want) - } + got := f.m.SetMaxRate(23) + assert.Equal(t, want, got, "manager did not set the rate on all throttlers") // Test MaxRates(). wantRates := map[string]int64{ "t1": 23, "t2": 23, } - if gotRates := f.m.MaxRates(); !reflect.DeepEqual(gotRates, wantRates) { - t.Errorf("manager did not set the rate on all throttlers. got = %v, want = %v", gotRates, wantRates) - } + gotRates := f.m.MaxRates() + assert.Equal(t, wantRates, gotRates, "manager did not set the rate on all throttlers") } func TestManager_GetConfiguration(t *testing.T) { @@ -108,24 +107,16 @@ func TestManager_GetConfiguration(t *testing.T) { "t2": defaultMaxReplicationLagModuleConfig.Clone().Configuration, } got, err := f.m.GetConfiguration("" /* all */) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got, want) { - t.Errorf("manager did not return the correct initial config for all throttlers. got = %v, want = %v", got, want) - } + require.NoError(t, err) + assert.Equal(t, want, got, "manager did not return the correct initial config for all throttlers") // Test GetConfiguration() when a specific throttler is requested. wantT2 := map[string]*throttlerdatapb.Configuration{ "t2": defaultMaxReplicationLagModuleConfig.Clone().Configuration, } gotT2, err := f.m.GetConfiguration("t2") - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(gotT2, wantT2) { - t.Errorf("manager did not return the correct initial config for throttler: %v got = %v, want = %v", "t2", gotT2, wantT2) - } + require.NoError(t, err) + assert.Equal(t, wantT2, gotT2, "manager did not return the correct initial config for throttler: t2") // Now change the config and then reset it back. newConfig := &throttlerdatapb.Configuration{ @@ -133,42 +124,35 @@ func TestManager_GetConfiguration(t *testing.T) { IgnoreNSlowestReplicas: defaultIgnoreNSlowestReplicas + 1, } allNames, err := f.m.UpdateConfiguration("", newConfig, false /* copyZeroValues */) - if err != nil { - t.Fatal(err) - } - // Verify it was changed. - if err := checkConfig(f.m, []string{"t1", "t2"}, allNames, defaultTargetLag+1, defaultIgnoreNSlowestReplicas+1); err != nil { - t.Fatal(err) - } + require.NoError(t, err) + + err = checkConfig(f.m, []string{"t1", "t2"}, allNames, defaultTargetLag+1, defaultIgnoreNSlowestReplicas+1) + require.NoError(t, err) + // Reset only "t2". - if names, err := f.m.ResetConfiguration("t2"); err != nil || !reflect.DeepEqual(names, []string{"t2"}) { - t.Fatalf("Reset failed or returned wrong throttler names: %v err: %v", names, err) - } + names, err := f.m.ResetConfiguration("t2") + require.NoError(t, err) + assert.Equal(t, []string{"t2"}, names, "Reset failed or returned wrong throttler names") + gotT2AfterReset, err := f.m.GetConfiguration("t2") - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(gotT2AfterReset, wantT2) { - t.Errorf("manager did not return the correct initial config for throttler %v after reset: got = %v, want = %v", "t2", gotT2AfterReset, wantT2) - } + require.NoError(t, err) + assert.Equal(t, wantT2, gotT2AfterReset, "manager did not return the correct initial config for throttler t2 after reset") + // Reset all throttlers. - if names, err := f.m.ResetConfiguration(""); err != nil || !reflect.DeepEqual(names, []string{"t1", "t2"}) { - t.Fatalf("Reset failed or returned wrong throttler names: %v err: %v", names, err) - } + + names, err = f.m.ResetConfiguration("") + require.NoError(t, err) + assert.Equal(t, []string{"t1", "t2"}, names, "Reset failed or returned wrong throttler names") + gotAfterReset, err := f.m.GetConfiguration("") - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(gotAfterReset, want) { - t.Errorf("manager did not return the correct initial config for all throttlers after reset. got = %v, want = %v", got, want) - } + require.NoError(t, err) + assert.Equal(t, want, gotAfterReset, "manager did not return the correct initial config for all throttlers after reset") } func TestManager_UpdateConfiguration_Error(t *testing.T) { f := &managerTestFixture{} - if err := f.setUp(); err != nil { - t.Fatal(err) - } + err := f.setUp() + require.NoError(t, err) defer f.tearDown() // Check that errors from Verify() are correctly propagated. @@ -176,21 +160,15 @@ func TestManager_UpdateConfiguration_Error(t *testing.T) { // max < 2 is not allowed. MaxReplicationLagSec: 1, } - if _, err := f.m.UpdateConfiguration("t2", invalidConfig, false /* copyZeroValues */); err == nil { - t.Fatal("expected error but got nil") - } else { - want := "max_replication_lag_sec must be >= 2" - if !strings.Contains(err.Error(), want) { - t.Fatalf("received wrong error. got = %v, want contains = %v", err, want) - } - } + _, err = f.m.UpdateConfiguration("t2", invalidConfig, false /* copyZeroValues */) + wantErr := "max_replication_lag_sec must be >= 2" + require.ErrorContains(t, err, wantErr) } func TestManager_UpdateConfiguration_Partial(t *testing.T) { f := &managerTestFixture{} - if err := f.setUp(); err != nil { - t.Fatal(err) - } + err := f.setUp() + require.NoError(t, err) defer f.tearDown() // Verify that a partial update only updates that one field. @@ -199,47 +177,40 @@ func TestManager_UpdateConfiguration_Partial(t *testing.T) { IgnoreNSlowestReplicas: wantIgnoreNSlowestReplicas, } names, err := f.m.UpdateConfiguration("t2", partialConfig, false /* copyZeroValues */) - if err != nil { - t.Fatal(err) - } - if err := checkConfig(f.m, []string{"t2"}, names, defaultTargetLag, wantIgnoreNSlowestReplicas); err != nil { - t.Fatal(err) - } + require.NoError(t, err) + + err = checkConfig(f.m, []string{"t2"}, names, defaultTargetLag, wantIgnoreNSlowestReplicas) + require.NoError(t, err) + // Repeat test for all throttlers. allNames, err := f.m.UpdateConfiguration("" /* all */, partialConfig, false /* copyZeroValues */) - if err != nil { - t.Fatal(err) - } - if err := checkConfig(f.m, []string{"t1", "t2"}, allNames, defaultTargetLag, wantIgnoreNSlowestReplicas); err != nil { - t.Fatal(err) - } + require.NoError(t, err) + + err = checkConfig(f.m, []string{"t1", "t2"}, allNames, defaultTargetLag, wantIgnoreNSlowestReplicas) + require.NoError(t, err) } func TestManager_UpdateConfiguration_ZeroValues(t *testing.T) { f := &managerTestFixture{} - if err := f.setUp(); err != nil { - t.Fatal(err) - } + err := f.setUp() + require.NoError(t, err) defer f.tearDown() // Test the explicit copy of zero values. zeroValueConfig := defaultMaxReplicationLagModuleConfig.Configuration.CloneVT() zeroValueConfig.IgnoreNSlowestReplicas = 0 names, err := f.m.UpdateConfiguration("t2", zeroValueConfig, true /* copyZeroValues */) - if err != nil { - t.Fatal(err) - } - if err := checkConfig(f.m, []string{"t2"}, names, defaultTargetLag, 0); err != nil { - t.Fatal(err) - } + require.NoError(t, err) + + err = checkConfig(f.m, []string{"t2"}, names, defaultTargetLag, 0) + require.NoError(t, err) + // Repeat test for all throttlers. allNames, err := f.m.UpdateConfiguration("" /* all */, zeroValueConfig, true /* copyZeroValues */) - if err != nil { - t.Fatal(err) - } - if err := checkConfig(f.m, []string{"t1", "t2"}, allNames, defaultTargetLag, 0); err != nil { - t.Fatal(err) - } + require.NoError(t, err) + + err = checkConfig(f.m, []string{"t1", "t2"}, allNames, defaultTargetLag, 0) + require.NoError(t, err) } func checkConfig(m *managerImpl, throttlers []string, updatedThrottlers []string, targetLag int64, ignoreNSlowestReplicas int32) error { diff --git a/go/vt/throttler/memory_test.go b/go/vt/throttler/memory_test.go index 899e175672a..7dcc13301f7 100644 --- a/go/vt/throttler/memory_test.go +++ b/go/vt/throttler/memory_test.go @@ -20,168 +20,157 @@ import ( "testing" "time" - "vitess.io/vitess/go/vt/log" + "github.com/stretchr/testify/require" ) func TestMemory(t *testing.T) { m := newMemory(5, 1*time.Second, 0.10) // Add several good rates. - if err := m.markGood(201); err != nil { - log.Errorf("m.markGood(201) failed :%v ", err) - } + err := m.markGood(201) + require.NoError(t, err) want200 := int64(200) - if got := m.highestGood(); got != want200 { - t.Fatalf("memory with one good entry: got = %v, want = %v", got, want200) - } + got := m.highestGood() + require.Equal(t, want200, got, "memory with one good entry") - //log error - if err := m.markGood(101); err != nil { - log.Errorf("m.markGood(101) failed :%v ", err) - } + err = m.markGood(101) + require.NoError(t, err) - if got := m.highestGood(); got != want200 { - t.Fatalf("wrong order within memory: got = %v, want = %v", got, want200) - } + got = m.highestGood() + require.Equal(t, want200, got, "wrong order within memory") - //log error - if err := m.markGood(301); err != nil { - log.Errorf(" m.markGood(301) failed :%v ", err) - } + err = m.markGood(301) + require.NoError(t, err) want300 := int64(300) - if got := m.highestGood(); got != want300 { - t.Fatalf("wrong order within memory: got = %v, want = %v", got, want300) - } - m.markGood(306) + got = m.highestGood() + require.Equal(t, want300, got, "wrong order within memory") + + err = m.markGood(306) + require.NoError(t, err) + want305 := int64(305) - if got := m.highestGood(); got != want305 { - t.Fatalf("wrong order within memory: got = %v, want = %v", got, want305) - } + got = m.highestGood() + require.Equal(t, want305, got, "wrong order within memory") // 300 and 305 will turn from good to bad. - if got := m.lowestBad(); got != 0 { - t.Fatalf("lowestBad should return zero value when no bad rate is recorded yet: got = %v", got) - } - - //log error - if err := m.markBad(300, sinceZero(0)); err != nil { - log.Errorf(" m.markBad(300, sinceZero(0)) failed :%v ", err) - } - - if got, want := m.lowestBad(), want300; got != want { - t.Fatalf("bad rate was not recorded: got = %v, want = %v", got, want) - } - if got := m.highestGood(); got != want200 { - t.Fatalf("new lower bad rate did not invalidate previous good rates: got = %v, want = %v", got, want200) - } - - //log error - if err := m.markBad(311, sinceZero(0)); err != nil { - log.Errorf(" m.markBad(311, sinceZero(0)) failed :%v ", err) - } - - if got := m.lowestBad(); got != want300 { - t.Fatalf("bad rates higher than the current one should be ignored: got = %v, want = %v", got, want300) - } + got = m.lowestBad() + require.Equal(t, int64(0), got, "lowestBad should return zero value when no bad rate is recorded yet") + + err = m.markBad(300, sinceZero(0)) + require.NoError(t, err) + + got = m.lowestBad() + require.Equal(t, want300, got, "bad rate was not recorded") + + got = m.highestGood() + require.Equal(t, want200, got, "new lower bad rate did not invalidate previous good rates") + + err = m.markBad(311, sinceZero(0)) + require.NoError(t, err) + + got = m.lowestBad() + require.Equal(t, want300, got, "bad rates higher than the current one should be ignored") // a good 601 will be ignored because the first bad is at 300. - if err := m.markGood(601); err == nil { - t.Fatal("good rates cannot go beyond the lowest bad rate: should have returned an error") - } - if got := m.lowestBad(); got != want300 { - t.Fatalf("good rates cannot go beyond the lowest bad rate: got = %v, want = %v", got, want300) - } - if got := m.highestGood(); got != want200 { - t.Fatalf("good rates beyond the lowest bad rate must be ignored: got = %v, want = %v", got, want200) - } + err = m.markGood(601) + require.Error(t, err, "good rates cannot go beyond the lowest bad rate") + + got = m.lowestBad() + require.Equal(t, want300, got, "good rates cannot go beyond the lowest bad rate") + + got = m.highestGood() + require.Equal(t, want200, got, "good rates beyond the lowest bad rate must be ignored") // 199 will be rounded up to 200. - err := m.markBad(199, sinceZero(0)) + err = m.markBad(199, sinceZero(0)) + require.NoError(t, err) - if err != nil { - t.Fatalf(" m.markBad(199, sinceZero(0)) failed :%v ", err) - } + got = m.lowestBad() + require.Equal(t, want200, got, "bad rate was not updated") - if got := m.lowestBad(); got != want200 { - t.Fatalf("bad rate was not updated: got = %v, want = %v", got, want200) - } want100 := int64(100) - if got := m.highestGood(); got != want100 { - t.Fatalf("previous highest good rate was not marked as bad: got = %v, want = %v", got, want100) - } + got = m.highestGood() + require.Equal(t, want100, got, "previous highest good rate was not marked as bad") } func TestMemory_markDownIgnoresDrasticBadValues(t *testing.T) { m := newMemory(1, 1*time.Second, 0.10) good := int64(1000) bad := int64(1001) - m.markGood(good) - m.markBad(bad, sinceZero(0)) - if got := m.highestGood(); got != good { - t.Fatalf("good rate was not correctly inserted: got = %v, want = %v", got, good) - } - if got := m.lowestBad(); got != bad { - t.Fatalf("bad rate was not correctly inserted: got = %v, want = %v", got, bad) - } - - if err := m.markBad(500, sinceZero(0)); err == nil { - t.Fatal("bad rate should have been ignored and an error should have been returned") - } - if got := m.highestGood(); got != good { - t.Fatalf("bad rate should have been ignored: got = %v, want = %v", got, good) - } - if got := m.lowestBad(); got != bad { - t.Fatalf("bad rate should have been ignored: got = %v, want = %v", got, bad) - } + + err := m.markGood(good) + require.NoError(t, err) + + err = m.markBad(bad, sinceZero(0)) + require.NoError(t, err) + + got := m.highestGood() + require.Equal(t, good, got, "good rate was not correctly inserted") + + got = m.lowestBad() + require.Equal(t, bad, got, "bad rate was not correctly inserted") + + err = m.markBad(500, sinceZero(0)) + require.Error(t, err, "bad rate should have been ignored and an error should have been returned") + + got = m.highestGood() + require.Equal(t, good, got, "bad rate should have been ignored") + + got = m.lowestBad() + require.Equal(t, bad, got, "bad rate should have been ignored") } func TestMemory_Aging(t *testing.T) { m := newMemory(1, 2*time.Second, 0.10) - m.markBad(100, sinceZero(0)) - if got, want := m.lowestBad(), int64(100); got != want { - t.Fatalf("bad rate was not correctly inserted: got = %v, want = %v", got, want) - } + err := m.markBad(100, sinceZero(0)) + require.NoError(t, err) + + got := m.lowestBad() + require.Equal(t, int64(100), got, "bad rate was not correctly inserted") // Bad rate successfully ages by 10%. m.ageBadRate(sinceZero(2 * time.Second)) - if got, want := m.lowestBad(), int64(110); got != want { - t.Fatalf("bad rate should have been increased due to its age: got = %v, want = %v", got, want) - } + + got = m.lowestBad() + require.Equal(t, int64(110), got, "bad rate should have been increased due to its age") // A recent aging resets the age timer. m.ageBadRate(sinceZero(2 * time.Second)) - if got, want := m.lowestBad(), int64(110); got != want { - t.Fatalf("a bad rate should not age again until the age is up again: got = %v, want = %v", got, want) - } + got = m.lowestBad() + require.Equal(t, int64(110), got, "a bad rate should not age again until the age is up again") // The age timer will be reset if the bad rate changes. - m.markBad(100, sinceZero(3*time.Second)) + err = m.markBad(100, sinceZero(3*time.Second)) + require.NoError(t, err) + m.ageBadRate(sinceZero(4 * time.Second)) - if got, want := m.lowestBad(), int64(100); got != want { - t.Fatalf("bad rate must not age yet: got = %v, want = %v", got, want) - } + + got = m.lowestBad() + require.Equal(t, int64(100), got, "bad rate must not age yet") // The age timer won't be reset when the rate stays the same. - m.markBad(100, sinceZero(4*time.Second)) + err = m.markBad(100, sinceZero(4*time.Second)) + require.NoError(t, err) + m.ageBadRate(sinceZero(5 * time.Second)) - if got, want := m.lowestBad(), int64(110); got != want { - t.Fatalf("bad rate should have aged again: got = %v, want = %v", got, want) - } + + got = m.lowestBad() + require.Equal(t, int64(110), got, "bad rate should have aged again") // Update the aging config. It will be effective immediately. m.updateAgingConfiguration(1*time.Second, 0.05) m.ageBadRate(sinceZero(6 * time.Second)) - if got, want := m.lowestBad(), int64(115); got != want { - t.Fatalf("bad rate should have aged after the configuration update: got = %v, want = %v", got, want) - } + + got = m.lowestBad() + require.Equal(t, int64(115), got, "bad rate should have aged after the configuration update") // If the new bad rate is not higher, it should increase by the memory granularity at least. m.markBad(5, sinceZero(10*time.Second)) m.ageBadRate(sinceZero(11 * time.Second)) - if got, want := m.lowestBad(), int64(5+memoryGranularity); got != want { - t.Fatalf("bad rate should have aged after the configuration update: got = %v, want = %v", got, want) - } + + got = m.lowestBad() + require.Equal(t, int64(5+memoryGranularity), got, "bad rate should have aged after the configuration update") } diff --git a/go/vt/throttler/replication_lag_cache_test.go b/go/vt/throttler/replication_lag_cache_test.go index 312f97e1999..135c0f03956 100644 --- a/go/vt/throttler/replication_lag_cache_test.go +++ b/go/vt/throttler/replication_lag_cache_test.go @@ -20,6 +20,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/vt/discovery" ) @@ -33,44 +35,39 @@ func TestReplicationLagCache(t *testing.T) { // If there is no entry yet, a zero struct is returned. zeroEntry := c.atOrAfter(r1Key, sinceZero(0*time.Second)) - if !zeroEntry.isZero() { - t.Fatalf("atOrAfter() should have returned a zero entry but did not: %v", zeroEntry) - } + require.True(t, zeroEntry.isZero(), "atOrAfter() should have returned a zero entry") // First entry at 1s. c.add(lagRecord(sinceZero(1*time.Second), r1, 1)) - if got, want := c.latest(r1Key).time, sinceZero(1*time.Second); got != want { - t.Fatalf("latest(r1) = %v, want = %v", got, want) - } + got, want := c.latest(r1Key).time, sinceZero(1*time.Second) + require.Equal(t, want, got) // Second entry at 2s makes the cache full. c.add(lagRecord(sinceZero(2*time.Second), r1, 2)) - if got, want := c.latest(r1Key).time, sinceZero(2*time.Second); got != want { - t.Fatalf("latest(r1) = %v, want = %v", got, want) - } - if got, want := c.atOrAfter(r1Key, sinceZero(1*time.Second)).time, sinceZero(1*time.Second); got != want { - t.Fatalf("atOrAfter(r1) = %v, want = %v", got, want) - } + got, want = c.latest(r1Key).time, sinceZero(2*time.Second) + require.Equal(t, want, got) + + got, want = c.atOrAfter(r1Key, sinceZero(1*time.Second)).time, sinceZero(1*time.Second) + require.Equal(t, want, got) // Third entry at 3s evicts the 1s entry. c.add(lagRecord(sinceZero(3*time.Second), r1, 3)) - if got, want := c.latest(r1Key).time, sinceZero(3*time.Second); got != want { - t.Fatalf("latest(r1) = %v, want = %v", got, want) - } + got, want = c.latest(r1Key).time, sinceZero(3*time.Second) + require.Equal(t, want, got) + // Requesting an entry at 1s or after gets us the entry for 2s. - if got, want := c.atOrAfter(r1Key, sinceZero(1*time.Second)).time, sinceZero(2*time.Second); got != want { - t.Fatalf("atOrAfter(r1) = %v, want = %v", got, want) - } + got, want = c.atOrAfter(r1Key, sinceZero(1*time.Second)).time, sinceZero(2*time.Second) + require.Equal(t, want, got) // Wrap around one more time. Entries at 4s and 5s should be left. c.add(lagRecord(sinceZero(4*time.Second), r1, 4)) c.add(lagRecord(sinceZero(5*time.Second), r1, 5)) - if got, want := c.latest(r1Key).time, sinceZero(5*time.Second); got != want { - t.Fatalf("latest(r1) = %v, want = %v", got, want) - } - if got, want := c.atOrAfter(r1Key, sinceZero(1*time.Second)).time, sinceZero(4*time.Second); got != want { - t.Fatalf("atOrAfter(r1) = %v, want = %v", got, want) - } + + got, want = c.latest(r1Key).time, sinceZero(5*time.Second) + require.Equal(t, want, got) + + got, want = c.atOrAfter(r1Key, sinceZero(1*time.Second)).time, sinceZero(4*time.Second) + require.Equal(t, want, got) } func TestReplicationLagCache_SortByLag(t *testing.T) { @@ -80,14 +77,10 @@ func TestReplicationLagCache_SortByLag(t *testing.T) { c.add(lagRecord(sinceZero(1*time.Second), r1, 30)) c.sortByLag(1 /* ignoreNSlowestReplicas */, 30 /* minimumReplicationLag */) - if c.slowReplicas[r1Key] { - t.Fatal("the only replica tracked should not get ignored") - } + require.False(t, c.slowReplicas[r1Key], "the only replica tracked should not get ignored") c.add(lagRecord(sinceZero(1*time.Second), r2, 1)) c.sortByLag(1 /* ignoreNSlowestReplicas */, 1 /* minimumReplicationLag */) - if !c.slowReplicas[r1Key] { - t.Fatal("r1 should be tracked as a slow replica") - } + require.True(t, c.slowReplicas[r1Key], "r1 should be tracked as a slow replica") } diff --git a/go/vt/throttler/result_test.go b/go/vt/throttler/result_test.go index 9efc7df9412..8cc5357ef7b 100644 --- a/go/vt/throttler/result_test.go +++ b/go/vt/throttler/result_test.go @@ -17,9 +17,10 @@ limitations under the License. package throttler import ( - "reflect" "testing" "time" + + "github.com/stretchr/testify/require" ) var ( @@ -127,9 +128,7 @@ reason: emergency state decreased the rate`, for _, tc := range testcases { got := tc.r.String() - if got != tc.want { - t.Fatalf("record.String() = %v, want = %v for full record: %#v", got, tc.want, tc.r) - } + require.Equal(t, tc.want, got) } } @@ -143,19 +142,16 @@ func TestResultRing(t *testing.T) { // Use the ring partially. rr.add(r1) - if got, want := rr.latestValues(), []result{r1}; !reflect.DeepEqual(got, want) { - t.Fatalf("items not correctly added to resultRing. got = %v, want = %v", got, want) - } + got, want := rr.latestValues(), []result{r1} + require.Equal(t, want, got, "items not correctly added to resultRing") // Use it fully. rr.add(r2) - if got, want := rr.latestValues(), []result{r2, r1}; !reflect.DeepEqual(got, want) { - t.Fatalf("items not correctly added to resultRing. got = %v, want = %v", got, want) - } + got, want = rr.latestValues(), []result{r2, r1} + require.Equal(t, want, got, "items not correctly added to resultRing") // Let it wrap. rr.add(r3) - if got, want := rr.latestValues(), []result{r3, r2}; !reflect.DeepEqual(got, want) { - t.Fatalf("resultRing did not wrap correctly. got = %v, want = %v", got, want) - } + got, want = rr.latestValues(), []result{r3, r2} + require.Equal(t, want, got, "resultRing did not wrap correctly") } diff --git a/go/vt/throttler/thread_throttler_test.go b/go/vt/throttler/thread_throttler_test.go index 7cb27e76487..2f97a66c6bc 100644 --- a/go/vt/throttler/thread_throttler_test.go +++ b/go/vt/throttler/thread_throttler_test.go @@ -19,6 +19,8 @@ package throttler import ( "testing" "time" + + "github.com/stretchr/testify/require" ) func TestThrottle_NoBurst(t *testing.T) { @@ -28,11 +30,10 @@ func TestThrottle_NoBurst(t *testing.T) { // 1. This means that in any time interval of length t seconds, the throttler should // not allow more than floor(2*t+1) requests. For example, in the interval [1500ms, 1501ms], of // length 1ms, we shouldn't be able to send more than floor(2*10^-3+1)=1 requests. - if gotBackoff := tt.throttle(sinceZero(1500 * time.Millisecond)); gotBackoff != NotThrottled { - t.Fatalf("throttler should not have throttled us: backoff = %v", gotBackoff) - } + gotBackoff := tt.throttle(sinceZero(1500 * time.Millisecond)) + require.Equal(t, NotThrottled, gotBackoff, "throttler should not have throttled us") + wantBackoff := 499 * time.Millisecond - if gotBackoff := tt.throttle(sinceZero(1501 * time.Millisecond)); gotBackoff != wantBackoff { - t.Fatalf("throttler should have throttled us. got = %v, want = %v", gotBackoff, wantBackoff) - } + gotBackoff = tt.throttle(sinceZero(1501 * time.Millisecond)) + require.Equal(t, wantBackoff, gotBackoff, "throttler should have throttled us") } diff --git a/go/vt/throttler/throttler_test.go b/go/vt/throttler/throttler_test.go index 0bb0ed0387a..b33bb2ca255 100644 --- a/go/vt/throttler/throttler_test.go +++ b/go/vt/throttler/throttler_test.go @@ -18,9 +18,10 @@ package throttler import ( "runtime" - "strings" "testing" "time" + + "github.com/stretchr/testify/require" ) // The main purpose of the benchmarks below is to demonstrate the functionality @@ -176,35 +177,30 @@ func TestThrottle(t *testing.T) { // 2 QPS should divide the current second into two chunks of 500 ms: // a) [1s, 1.5s), b) [1.5s, 2s) // First call goes through since the chunk is not "used" yet. - if gotBackoff := throttler.Throttle(0); gotBackoff != NotThrottled { - t.Fatalf("throttler should not have throttled us: backoff = %v", gotBackoff) - } + gotBackoff := throttler.Throttle(0) + require.Equal(t, NotThrottled, gotBackoff, "throttler should not have throttled us") // Next call should tell us to backoff until we reach the second chunk. fc.setNow(1000 * time.Millisecond) wantBackoff := 500 * time.Millisecond - if gotBackoff := throttler.Throttle(0); gotBackoff != wantBackoff { - t.Fatalf("throttler should have throttled us. got = %v, want = %v", gotBackoff, wantBackoff) - } + gotBackoff = throttler.Throttle(0) + require.Equal(t, wantBackoff, gotBackoff, "throttler should have throttled us") // Some time elpased, but we are still in the first chunk and must backoff. fc.setNow(1111 * time.Millisecond) wantBackoff2 := 389 * time.Millisecond - if gotBackoff := throttler.Throttle(0); gotBackoff != wantBackoff2 { - t.Fatalf("throttler should have still throttled us. got = %v, want = %v", gotBackoff, wantBackoff2) - } + gotBackoff = throttler.Throttle(0) + require.Equal(t, wantBackoff2, gotBackoff, "throttler should have still throttled us") // Enough time elapsed that we are in the second chunk now. fc.setNow(1500 * time.Millisecond) - if gotBackoff := throttler.Throttle(0); gotBackoff != NotThrottled { - t.Fatalf("throttler should not have throttled us: backoff = %v", gotBackoff) - } + gotBackoff = throttler.Throttle(0) + require.Equal(t, NotThrottled, gotBackoff, "throttler should not have throttled us") // We're in the third chunk and are allowed to issue the third request. fc.setNow(2001 * time.Millisecond) - if gotBackoff := throttler.Throttle(0); gotBackoff != NotThrottled { - t.Fatalf("throttler should not have throttled us: backoff = %v", gotBackoff) - } + gotBackoff = throttler.Throttle(0) + require.Equal(t, NotThrottled, gotBackoff, "throttler should not have throttled us") } func TestThrottle_RateRemainderIsDistributedAcrossThreads(t *testing.T) { @@ -216,9 +212,8 @@ func TestThrottle_RateRemainderIsDistributedAcrossThreads(t *testing.T) { fc.setNow(1000 * time.Millisecond) // Out of 5 QPS, each thread gets 1 and two threads get 1 query extra. for threadID := 0; threadID < 2; threadID++ { - if gotBackoff := throttler.Throttle(threadID); gotBackoff != NotThrottled { - t.Fatalf("throttler should not have throttled thread %d: backoff = %v", threadID, gotBackoff) - } + gotBackoff := throttler.Throttle(threadID) + require.Equalf(t, NotThrottled, gotBackoff, "throttler should not have throttled thread %d", threadID) } fc.setNow(1500 * time.Millisecond) @@ -229,21 +224,18 @@ func TestThrottle_RateRemainderIsDistributedAcrossThreads(t *testing.T) { threadsWithMoreThanOneQPS++ } else { wantBackoff := 500 * time.Millisecond - if gotBackoff != wantBackoff { - t.Fatalf("throttler did throttle us with the wrong backoff time. got = %v, want = %v", gotBackoff, wantBackoff) - } + require.Equal(t, wantBackoff, gotBackoff, "throttler did throttle us with the wrong backoff time") } } if want := 2; threadsWithMoreThanOneQPS != want { - t.Fatalf("wrong number of threads were throttled: %v != %v", threadsWithMoreThanOneQPS, want) + require.Equal(t, want, threadsWithMoreThanOneQPS, "wrong number of threads were throttled") } // Now, all threads are throttled. for threadID := 0; threadID < 2; threadID++ { wantBackoff := 500 * time.Millisecond - if gotBackoff := throttler.Throttle(threadID); gotBackoff != wantBackoff { - t.Fatalf("throttler should have throttled thread %d. got = %v, want = %v", threadID, gotBackoff, wantBackoff) - } + gotBackoff := throttler.Throttle(threadID) + require.Equalf(t, wantBackoff, gotBackoff, "throttler should have throttled thread %d", threadID) } } @@ -256,16 +248,14 @@ func TestThreadFinished(t *testing.T) { // [1000ms, 2000ms): Each thread consumes their 1 QPS. fc.setNow(1000 * time.Millisecond) for threadID := 0; threadID < 2; threadID++ { - if gotBackoff := throttler.Throttle(threadID); gotBackoff != NotThrottled { - t.Fatalf("throttler should not have throttled thread %d: backoff = %v", threadID, gotBackoff) - } + gotBackoff := throttler.Throttle(threadID) + require.Equalf(t, NotThrottled, gotBackoff, "throttler should not have throttled thread %d", threadID) } // Now they would be throttled. wantBackoff := 1000 * time.Millisecond for threadID := 0; threadID < 2; threadID++ { - if gotBackoff := throttler.Throttle(threadID); gotBackoff != wantBackoff { - t.Fatalf("throttler should have throttled thread %d. got = %v, want = %v", threadID, gotBackoff, wantBackoff) - } + gotBackoff := throttler.Throttle(threadID) + require.Equalf(t, wantBackoff, gotBackoff, "throttler should have throttled thread %d", threadID) } // [2000ms, 3000ms): One thread finishes, other one gets remaining 1 QPS extra. @@ -288,29 +278,23 @@ func TestThreadFinished(t *testing.T) { } // Consume 2 QPS. - if gotBackoff := throttler.Throttle(0); gotBackoff != NotThrottled { - t.Fatalf("throttler should not have throttled us: backoff = %v", gotBackoff) - } + gotBackoff := throttler.Throttle(0) + require.Equal(t, NotThrottled, gotBackoff, "throttler should not have throttled us") + fc.setNow(2500 * time.Millisecond) - if gotBackoff := throttler.Throttle(0); gotBackoff != NotThrottled { - t.Fatalf("throttler should not have throttled us: backoff = %v", gotBackoff) - } + gotBackoff = throttler.Throttle(0) + require.Equal(t, NotThrottled, gotBackoff, "throttler should not have throttled us") // 2 QPS are consumed. Thread 0 should be throttled now. wantBackoff2 := 500 * time.Millisecond - if gotBackoff := throttler.Throttle(0); gotBackoff != wantBackoff2 { - t.Fatalf("throttler should have throttled us. got = %v, want = %v", gotBackoff, wantBackoff2) - } + gotBackoff = throttler.Throttle(0) + require.Equal(t, wantBackoff2, gotBackoff, "throttler should have throttled us") // Throttle() from a finished thread will panic. defer func() { msg := recover() - if msg == nil { - t.Fatal("Throttle() from a thread which called ThreadFinished() should panic") - } - if !strings.Contains(msg.(string), "already finished") { - t.Fatalf("Throttle() after ThreadFinished() panic'd for wrong reason: %v", msg) - } + require.NotNil(t, msg) + require.Contains(t, msg, "already finished", "Throttle() after ThreadFinished() panic'd for wrong reason") }() throttler.Throttle(1) } @@ -326,19 +310,18 @@ func TestThrottle_MaxRateIsZero(t *testing.T) { fc.setNow(1000 * time.Millisecond) wantBackoff := 1000 * time.Millisecond - if gotBackoff := throttler.Throttle(0); gotBackoff != wantBackoff { - t.Fatalf("throttler should have throttled us. got = %v, want = %v", gotBackoff, wantBackoff) - } + gotBackoff := throttler.Throttle(0) + require.Equal(t, wantBackoff, gotBackoff, "throttler should have throttled us") + fc.setNow(1111 * time.Millisecond) wantBackoff2 := 1000 * time.Millisecond - if gotBackoff := throttler.Throttle(0); gotBackoff != wantBackoff2 { - t.Fatalf("throttler should have throttled us. got = %v, want = %v", gotBackoff, wantBackoff2) - } + gotBackoff = throttler.Throttle(0) + require.Equal(t, wantBackoff2, gotBackoff, "throttler should have throttled us") + fc.setNow(2000 * time.Millisecond) wantBackoff3 := 1000 * time.Millisecond - if gotBackoff := throttler.Throttle(0); gotBackoff != wantBackoff3 { - t.Fatalf("throttler should have throttled us. got = %v, want = %v", gotBackoff, wantBackoff3) - } + gotBackoff = throttler.Throttle(0) + require.Equal(t, wantBackoff3, gotBackoff, "throttler should have throttled us") } func TestThrottle_MaxRateDisabled(t *testing.T) { @@ -349,9 +332,8 @@ func TestThrottle_MaxRateDisabled(t *testing.T) { fc.setNow(1000 * time.Millisecond) // No QPS set. 10 requests in a row are fine. for i := 0; i < 10; i++ { - if gotBackoff := throttler.Throttle(0); gotBackoff != NotThrottled { - t.Fatalf("throttler should not have throttled us: request = %v, backoff = %v", i, gotBackoff) - } + gotBackoff := throttler.Throttle(0) + require.Equal(t, NotThrottled, gotBackoff, "throttler should not have throttled us") } } @@ -368,15 +350,13 @@ func TestThrottle_MaxRateLowerThanThreadCount(t *testing.T) { // must not starve. fc.setNow(1000 * time.Millisecond) for threadID := 0; threadID < 1; threadID++ { - if gotBackoff := throttler.Throttle(threadID); gotBackoff != NotThrottled { - t.Fatalf("throttler should not have throttled thread %d: backoff = %v", threadID, gotBackoff) - } + gotBackoff := throttler.Throttle(threadID) + require.Equalf(t, NotThrottled, gotBackoff, "throttler should not have throttled thread %d", threadID) } wantBackoff := 1000 * time.Millisecond for threadID := 0; threadID < 1; threadID++ { - if gotBackoff := throttler.Throttle(threadID); gotBackoff != wantBackoff { - t.Fatalf("throttler should have throttled thread %d: got = %v, want = %v", threadID, gotBackoff, wantBackoff) - } + gotBackoff := throttler.Throttle(threadID) + require.Equalf(t, wantBackoff, gotBackoff, "throttler should have throttled thread %d", threadID) } } @@ -400,12 +380,8 @@ func TestClose(t *testing.T) { defer func() { msg := recover() - if msg == nil { - t.Fatal("Throttle() after Close() should panic") - } - if !strings.Contains(msg.(string), "must not access closed Throttler") { - t.Fatalf("Throttle() after ThreadFinished() panic'd for wrong reason: %v", msg) - } + require.NotNil(t, msg) + require.Contains(t, msg, "must not access closed Throttler", "Throttle() after ThreadFinished() panic'd for wrong reason") }() throttler.Throttle(0) } @@ -417,12 +393,8 @@ func TestThreadFinished_SecondCallPanics(t *testing.T) { defer func() { msg := recover() - if msg == nil { - t.Fatal("Second ThreadFinished() after ThreadFinished() should panic") - } - if !strings.Contains(msg.(string), "already finished") { - t.Fatalf("ThreadFinished() after ThreadFinished() panic'd for wrong reason: %v", msg) - } + require.NotNil(t, msg) + require.Contains(t, msg, "already finished", "Throttle() after ThreadFinished() panic'd for wrong reason") }() throttler.ThreadFinished(0) } diff --git a/go/vt/throttler/throttlerlogz_test.go b/go/vt/throttler/throttlerlogz_test.go index 6fdb137577c..22927f3f201 100644 --- a/go/vt/throttler/throttlerlogz_test.go +++ b/go/vt/throttler/throttlerlogz_test.go @@ -19,8 +19,9 @@ package throttler import ( "net/http" "net/http/httptest" - "strings" "testing" + + "github.com/stretchr/testify/require" ) func TestThrottlerlogzHandler_MissingSlash(t *testing.T) { @@ -30,9 +31,8 @@ func TestThrottlerlogzHandler_MissingSlash(t *testing.T) { throttlerlogzHandler(response, request, m) - if got, want := response.Body.String(), "invalid /throttlerlogz path"; !strings.Contains(got, want) { - t.Fatalf("/throttlerlogz without the slash does not work (the Go HTTP server does automatically redirect in practice though). got = %v, want = %v", got, want) - } + got := response.Body.String() + require.Contains(t, got, "invalid /throttlerlogz path", "/throttlerlogz without the slash does not work (the Go HTTP server does automatically redirect in practice though)") } func TestThrottlerlogzHandler_NonExistantThrottler(t *testing.T) { @@ -41,9 +41,8 @@ func TestThrottlerlogzHandler_NonExistantThrottler(t *testing.T) { throttlerlogzHandler(response, request, newManager()) - if got, want := response.Body.String(), `throttler not found`; !strings.Contains(got, want) { - t.Fatalf("/throttlerlogz page for non-existent t1 should not succeed. got = %v, want = %v", got, want) - } + got := response.Body.String() + require.Contains(t, got, "throttler not found", "/throttlerlogz page for non-existent t1 should not succeed") } func TestThrottlerlogzHandler(t *testing.T) { @@ -152,8 +151,6 @@ func TestThrottlerlogzHandler(t *testing.T) { throttlerlogzHandler(response, request, f.m) got := response.Body.String() - if !strings.Contains(got, tc.want) { - t.Fatalf("testcase '%v': result not shown in log. got = %v, want = %v", tc.desc, got, tc.want) - } + require.Containsf(t, got, tc.want, "testcase '%v': result not shown in log", tc.desc) } } diff --git a/go/vt/throttler/throttlerz_test.go b/go/vt/throttler/throttlerz_test.go index be40598468a..9fd95603439 100644 --- a/go/vt/throttler/throttlerz_test.go +++ b/go/vt/throttler/throttlerz_test.go @@ -19,8 +19,9 @@ package throttler import ( "net/http" "net/http/httptest" - "strings" "testing" + + "github.com/stretchr/testify/require" ) func TestThrottlerzHandler_MissingSlash(t *testing.T) { @@ -30,9 +31,8 @@ func TestThrottlerzHandler_MissingSlash(t *testing.T) { throttlerzHandler(response, request, m) - if got, want := response.Body.String(), "invalid /throttlerz path"; !strings.Contains(got, want) { - t.Fatalf("/throttlerz without the slash does not work (the Go HTTP server does automatically redirect in practice though). got = %v, want = %v", got, want) - } + got := response.Body.String() + require.Contains(t, got, "invalid /throttlerz path", "/throttlerz without the slash does not work (the Go HTTP server does automatically redirect in practice though)") } func TestThrottlerzHandler_List(t *testing.T) { @@ -47,12 +47,9 @@ func TestThrottlerzHandler_List(t *testing.T) { throttlerzHandler(response, request, f.m) - if got, want := response.Body.String(), `t1`; !strings.Contains(got, want) { - t.Fatalf("list does not include 't1'. got = %v, want = %v", got, want) - } - if got, want := response.Body.String(), `t2`; !strings.Contains(got, want) { - t.Fatalf("list does not include 't1'. got = %v, want = %v", got, want) - } + got := response.Body.String() + require.Contains(t, got, `t1`, "list does not include 't1'") + require.Contains(t, got, `t2`, "list does not include 't2'") } func TestThrottlerzHandler_Details(t *testing.T) { @@ -67,7 +64,6 @@ func TestThrottlerzHandler_Details(t *testing.T) { throttlerzHandler(response, request, f.m) - if got, want := response.Body.String(), `Details for Throttler 't1'`; !strings.Contains(got, want) { - t.Fatalf("details for 't1' not shown. got = %v, want = %v", got, want) - } + got := response.Body.String() + require.Contains(t, got, `Details for Throttler 't1'`, "details for 't1' not shown") }