Skip to content

Commit

Permalink
Fix race conditions.
Browse files Browse the repository at this point in the history
  • Loading branch information
beautifulentropy committed Sep 10, 2024
1 parent 2d134fc commit d0fc59e
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 26 deletions.
3 changes: 0 additions & 3 deletions cmd/admin/pause_identifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,9 @@ func TestReadingPauseCSV(t *testing.T) {
// PauseIdentifiersResponse. It does not maintain state of repaused identifiers.
type mockSAPaused struct {
sapb.StorageAuthorityClient
reqs []*sapb.PauseRequest
}

func (msa *mockSAPaused) PauseIdentifiers(ctx context.Context, in *sapb.PauseRequest, _ ...grpc.CallOption) (*sapb.PauseIdentifiersResponse, error) {
msa.reqs = append(msa.reqs, in)

return &sapb.PauseIdentifiersResponse{Paused: int64(len(in.Identifiers))}, nil
}

Expand Down
13 changes: 9 additions & 4 deletions cmd/admin/unpause_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"flag"
"fmt"
"os"
"slices"
"strconv"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -81,9 +82,11 @@ type unpauseCount struct {
// up to `parallelism` workers. It returns a count of the number of identifiers
// unpaused for each account and any accumulated errors.
func (a *admin) unpauseAccounts(ctx context.Context, accountIDs []int64, parallelism uint) ([]unpauseCount, error) {
if len(accountIDs) <= 0 {
if len(accountIDs) == 0 {
return nil, errors.New("no account IDs provided for unpausing")
}
slices.Sort(accountIDs)
accountIDs = slices.Compact(accountIDs)

countChan := make(chan unpauseCount, len(accountIDs))
work := make(chan int64, parallelism)
Expand Down Expand Up @@ -112,10 +115,12 @@ func (a *admin) unpauseAccounts(ctx context.Context, accountIDs []int64, paralle
for _, accountID := range accountIDs {
work <- accountID
}

wg.Wait()
close(work)
close(countChan)

go func() {
wg.Wait()
close(countChan)
}()

var unpauseCounts []unpauseCount
for count := range countChan {
Expand Down
33 changes: 14 additions & 19 deletions cmd/admin/unpause_account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,12 @@ func TestReadingUnpauseAccountsFile(t *testing.T) {
}
}

// mockSAPaused is a mock that always succeeds. It records each PauseRequest it
// receives, and returns the number of identifiers as a
// PauseIdentifiersResponse. It does not maintain state of repaused identifiers.
type mockSAUnpause struct {
sapb.StorageAuthorityClient
regIDCounter map[int64]int64
}

func (msa *mockSAUnpause) UnpauseAccount(ctx context.Context, in *sapb.RegistrationID, _ ...grpc.CallOption) (*sapb.Count, error) {
if _, ok := msa.regIDCounter[in.Id]; ok {
msa.regIDCounter[in.Id] += 1
}

return &sapb.Count{Count: msa.regIDCounter[in.Id]}, nil
return &sapb.Count{Count: 1}, nil
}

// mockSAUnpauseBroken is a mock that always returns an error.
Expand All @@ -89,19 +81,21 @@ func TestUnpauseAccounts(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
regIDs []int64
saImpl sapb.StorageAuthorityClient
expectErr bool
name string
regIDs []int64
saImpl sapb.StorageAuthorityClient
expectErr bool
expectCounts int
}{
{
name: "no data",
regIDs: nil,
expectErr: true,
},
{
name: "valid single entry",
regIDs: []int64{1},
name: "valid single entry",
regIDs: []int64{1},
expectCounts: 1,
},
{
name: "valid single entry but broken SA",
Expand All @@ -110,8 +104,9 @@ func TestUnpauseAccounts(t *testing.T) {
regIDs: []int64{1},
},
{
name: "valid multiple entries with duplicates",
regIDs: []int64{1, 1, 2, 3, 4},
name: "valid multiple entries with duplicates",
regIDs: []int64{1, 1, 2, 3, 4},
expectCounts: 4,
},
}

Expand All @@ -127,12 +122,12 @@ func TestUnpauseAccounts(t *testing.T) {
}
a := admin{sac: testCase.saImpl, log: log}

count, err := a.unpauseAccounts(context.Background(), testCase.regIDs, 10)
counts, err := a.unpauseAccounts(context.Background(), testCase.regIDs, 10)
if testCase.expectErr {
test.AssertError(t, err, "should have errored, but did not")
} else {
test.AssertNotError(t, err, "should not have errored")
test.AssertEquals(t, len(count), len(testCase.regIDs))
test.AssertEquals(t, testCase.expectCounts, len(counts))
}
})
}
Expand Down

0 comments on commit d0fc59e

Please sign in to comment.