From 2d134fc12a648510b862ff9ce1e6d591d57da6f8 Mon Sep 17 00:00:00 2001 From: Samantha Date: Thu, 29 Aug 2024 14:45:15 -0400 Subject: [PATCH 1/2] admin: Perform unpauseAccount batches in parallel --- cmd/admin/unpause_account.go | 84 +++++++++++++++++++++++-------- cmd/admin/unpause_account_test.go | 2 +- 2 files changed, 65 insertions(+), 21 deletions(-) diff --git a/cmd/admin/unpause_account.go b/cmd/admin/unpause_account.go index 3ec2306ba13..3636b7bb6ae 100644 --- a/cmd/admin/unpause_account.go +++ b/cmd/admin/unpause_account.go @@ -8,15 +8,19 @@ import ( "fmt" "os" "strconv" + "sync" + "sync/atomic" sapb "github.com/letsencrypt/boulder/sa/proto" + "github.com/letsencrypt/boulder/unpause" "golang.org/x/exp/maps" ) // subcommandUnpauseAccount encapsulates the "admin unpause-account" command. type subcommandUnpauseAccount struct { - batchFile string - regID int64 + accountID int64 + batchFile string + parallelism uint } var _ subcommand = (*subcommandUnpauseAccount)(nil) @@ -26,8 +30,9 @@ func (u *subcommandUnpauseAccount) Desc() string { } func (u *subcommandUnpauseAccount) Flags(flag *flag.FlagSet) { + flag.Int64Var(&u.accountID, "account", 0, "A single account ID to unpause") flag.StringVar(&u.batchFile, "batch-file", "", "Path to a file containing multiple account IDs where each is separated by a newline") - flag.Int64Var(&u.regID, "account", 0, "A single account ID to unpause") + flag.UintVar(&u.parallelism, "parallelism", 10, "The maximum number of concurrent unpause requests to send to the SA (default: 10)") } func (u *subcommandUnpauseAccount) Run(ctx context.Context, a *admin) error { @@ -35,7 +40,7 @@ func (u *subcommandUnpauseAccount) Run(ctx context.Context, a *admin) error { // to a non-default value. We use this to ensure that exactly one input // selection flag was given on the command line. setInputs := map[string]bool{ - "-account": u.regID != 0, + "-account": u.accountID != 0, "-batch-file": u.batchFile != "", } maps.DeleteFunc(setInputs, func(_ string, v bool) bool { return !v }) @@ -49,7 +54,7 @@ func (u *subcommandUnpauseAccount) Run(ctx context.Context, a *admin) error { var err error switch maps.Keys(setInputs)[0] { case "-account": - regIDs = []int64{u.regID} + regIDs = []int64{u.accountID} case "-batch-file": regIDs, err = a.readUnpauseAccountFile(u.batchFile) default: @@ -59,7 +64,7 @@ func (u *subcommandUnpauseAccount) Run(ctx context.Context, a *admin) error { return fmt.Errorf("collecting serials to revoke: %w", err) } - _, err = a.unpauseAccounts(ctx, regIDs) + _, err = a.unpauseAccounts(ctx, regIDs, u.parallelism) if err != nil { return err } @@ -67,24 +72,63 @@ func (u *subcommandUnpauseAccount) Run(ctx context.Context, a *admin) error { return nil } -// unpauseAccount allows administratively unpausing all identifiers for an -// account. Returns a slice of int64 which is counter of unpaused accounts or an -// error. -func (a *admin) unpauseAccounts(ctx context.Context, regIDs []int64) ([]int64, error) { - var count []int64 - if len(regIDs) <= 0 { - return count, errors.New("no regIDs sent for unpausing") +type unpauseCount struct { + accountID int64 + count int64 +} + +// unpauseAccount concurrently unpauses all identifiers for each account using +// up to `parallelism` workers. It returns a count of the number of identifiers +// unpaused for each account and any accumulated errors. +func (a *admin) unpauseAccounts(ctx context.Context, accountIDs []int64, parallelism uint) ([]unpauseCount, error) { + if len(accountIDs) <= 0 { + return nil, errors.New("no account IDs provided for unpausing") } - for _, regID := range regIDs { - response, err := a.sac.UnpauseAccount(ctx, &sapb.RegistrationID{Id: regID}) - if err != nil { - return count, err - } - count = append(count, response.Count) + countChan := make(chan unpauseCount, len(accountIDs)) + work := make(chan int64, parallelism) + + var wg sync.WaitGroup + var errCount atomic.Uint64 + for i := uint(0); i < parallelism; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for accountID := range work { + response, err := a.sac.UnpauseAccount(ctx, &sapb.RegistrationID{Id: accountID}) + if err != nil { + errCount.Add(1) + a.log.Errf("error unpausing accountID %d: %v", accountID, err) + continue + } + if response.Count >= unpause.RequestLimit { + work <- accountID + } + countChan <- unpauseCount{accountID: accountID, count: response.Count} + } + }() + } + + for _, accountID := range accountIDs { + work <- accountID + } + + wg.Wait() + close(work) + close(countChan) + + var unpauseCounts []unpauseCount + for count := range countChan { + // There could be multiple unpause requests for the same account ID if + // the account has more than `unpause.RequestLimit` identifiers. + unpauseCounts = append(unpauseCounts, count) + } + + if errCount.Load() > 0 { + return unpauseCounts, fmt.Errorf("encountered %d errors while unpausing; see logs above for details", errCount.Load()) } - return count, nil + return unpauseCounts, nil } // readUnpauseAccountFile parses the contents of a file containing one account diff --git a/cmd/admin/unpause_account_test.go b/cmd/admin/unpause_account_test.go index 4eed71f5f87..c81cccdc4d2 100644 --- a/cmd/admin/unpause_account_test.go +++ b/cmd/admin/unpause_account_test.go @@ -127,7 +127,7 @@ func TestUnpauseAccounts(t *testing.T) { } a := admin{sac: testCase.saImpl, log: log} - count, err := a.unpauseAccounts(context.Background(), testCase.regIDs) + count, err := a.unpauseAccounts(context.Background(), testCase.regIDs, 10) if testCase.expectErr { test.AssertError(t, err, "should have errored, but did not") } else { From c1b492084389ca4e13071d171c5261cbe0788a77 Mon Sep 17 00:00:00 2001 From: Samantha Date: Tue, 10 Sep 2024 17:18:49 -0400 Subject: [PATCH 2/2] Fix race conditions. --- cmd/admin/pause_identifier_test.go | 3 -- cmd/admin/unpause_account.go | 46 ++++++++++++++++++------------ cmd/admin/unpause_account_test.go | 33 +++++++++------------ 3 files changed, 42 insertions(+), 40 deletions(-) diff --git a/cmd/admin/pause_identifier_test.go b/cmd/admin/pause_identifier_test.go index 70dc050d183..937cf179107 100644 --- a/cmd/admin/pause_identifier_test.go +++ b/cmd/admin/pause_identifier_test.go @@ -84,12 +84,9 @@ func TestReadingPauseCSV(t *testing.T) { // PauseIdentifiersResponse. It does not maintain state of repaused identifiers. type mockSAPaused struct { sapb.StorageAuthorityClient - reqs []*sapb.PauseRequest } func (msa *mockSAPaused) PauseIdentifiers(ctx context.Context, in *sapb.PauseRequest, _ ...grpc.CallOption) (*sapb.PauseIdentifiersResponse, error) { - msa.reqs = append(msa.reqs, in) - return &sapb.PauseIdentifiersResponse{Paused: int64(len(in.Identifiers))}, nil } diff --git a/cmd/admin/unpause_account.go b/cmd/admin/unpause_account.go index 3636b7bb6ae..9a56dde05f7 100644 --- a/cmd/admin/unpause_account.go +++ b/cmd/admin/unpause_account.go @@ -7,6 +7,7 @@ import ( "flag" "fmt" "os" + "slices" "strconv" "sync" "sync/atomic" @@ -84,9 +85,11 @@ func (a *admin) unpauseAccounts(ctx context.Context, accountIDs []int64, paralle if len(accountIDs) <= 0 { return nil, errors.New("no account IDs provided for unpausing") } + slices.Sort(accountIDs) + accountIDs = slices.Compact(accountIDs) countChan := make(chan unpauseCount, len(accountIDs)) - work := make(chan int64, parallelism) + work := make(chan int64) var wg sync.WaitGroup var errCount atomic.Uint64 @@ -95,32 +98,39 @@ func (a *admin) unpauseAccounts(ctx context.Context, accountIDs []int64, paralle go func() { defer wg.Done() for accountID := range work { - response, err := a.sac.UnpauseAccount(ctx, &sapb.RegistrationID{Id: accountID}) - if err != nil { - errCount.Add(1) - a.log.Errf("error unpausing accountID %d: %v", accountID, err) - continue + totalCount := int64(0) + for { + response, err := a.sac.UnpauseAccount(ctx, &sapb.RegistrationID{Id: accountID}) + if err != nil { + errCount.Add(1) + a.log.Errf("error unpausing accountID %d: %v", accountID, err) + break + } + totalCount += response.Count + if response.Count < unpause.RequestLimit { + // All identifiers have been unpaused. + break + } } - if response.Count >= unpause.RequestLimit { - work <- accountID - } - countChan <- unpauseCount{accountID: accountID, count: response.Count} + countChan <- unpauseCount{accountID: accountID, count: totalCount} } }() } - for _, accountID := range accountIDs { - work <- accountID - } + go func() { + for _, accountID := range accountIDs { + work <- accountID + } + close(work) + }() - wg.Wait() - close(work) - close(countChan) + go func() { + wg.Wait() + close(countChan) + }() var unpauseCounts []unpauseCount for count := range countChan { - // There could be multiple unpause requests for the same account ID if - // the account has more than `unpause.RequestLimit` identifiers. unpauseCounts = append(unpauseCounts, count) } diff --git a/cmd/admin/unpause_account_test.go b/cmd/admin/unpause_account_test.go index c81cccdc4d2..f39b168fcbf 100644 --- a/cmd/admin/unpause_account_test.go +++ b/cmd/admin/unpause_account_test.go @@ -60,20 +60,12 @@ func TestReadingUnpauseAccountsFile(t *testing.T) { } } -// mockSAPaused is a mock that always succeeds. It records each PauseRequest it -// receives, and returns the number of identifiers as a -// PauseIdentifiersResponse. It does not maintain state of repaused identifiers. type mockSAUnpause struct { sapb.StorageAuthorityClient - regIDCounter map[int64]int64 } func (msa *mockSAUnpause) UnpauseAccount(ctx context.Context, in *sapb.RegistrationID, _ ...grpc.CallOption) (*sapb.Count, error) { - if _, ok := msa.regIDCounter[in.Id]; ok { - msa.regIDCounter[in.Id] += 1 - } - - return &sapb.Count{Count: msa.regIDCounter[in.Id]}, nil + return &sapb.Count{Count: 1}, nil } // mockSAUnpauseBroken is a mock that always returns an error. @@ -89,10 +81,11 @@ func TestUnpauseAccounts(t *testing.T) { t.Parallel() testCases := []struct { - name string - regIDs []int64 - saImpl sapb.StorageAuthorityClient - expectErr bool + name string + regIDs []int64 + saImpl sapb.StorageAuthorityClient + expectErr bool + expectCounts int }{ { name: "no data", @@ -100,8 +93,9 @@ func TestUnpauseAccounts(t *testing.T) { expectErr: true, }, { - name: "valid single entry", - regIDs: []int64{1}, + name: "valid single entry", + regIDs: []int64{1}, + expectCounts: 1, }, { name: "valid single entry but broken SA", @@ -110,8 +104,9 @@ func TestUnpauseAccounts(t *testing.T) { regIDs: []int64{1}, }, { - name: "valid multiple entries with duplicates", - regIDs: []int64{1, 1, 2, 3, 4}, + name: "valid multiple entries with duplicates", + regIDs: []int64{1, 1, 2, 3, 4}, + expectCounts: 4, }, } @@ -127,12 +122,12 @@ func TestUnpauseAccounts(t *testing.T) { } a := admin{sac: testCase.saImpl, log: log} - count, err := a.unpauseAccounts(context.Background(), testCase.regIDs, 10) + counts, err := a.unpauseAccounts(context.Background(), testCase.regIDs, 10) if testCase.expectErr { test.AssertError(t, err, "should have errored, but did not") } else { test.AssertNotError(t, err, "should not have errored") - test.AssertEquals(t, len(count), len(testCase.regIDs)) + test.AssertEquals(t, testCase.expectCounts, len(counts)) } }) }