From 130d4d3f56e0e5e2552aff5c0692ad7c4b361f0c Mon Sep 17 00:00:00 2001 From: Samantha Date: Thu, 29 Aug 2024 14:45:15 -0400 Subject: [PATCH] admin: Perform unpauseAccount batches in parallel --- cmd/admin/unpause_account.go | 83 +++++++++++++++++++++++-------- cmd/admin/unpause_account_test.go | 2 +- 2 files changed, 64 insertions(+), 21 deletions(-) diff --git a/cmd/admin/unpause_account.go b/cmd/admin/unpause_account.go index 3ec2306ba133..1a8027ec5a6e 100644 --- a/cmd/admin/unpause_account.go +++ b/cmd/admin/unpause_account.go @@ -8,15 +8,19 @@ import ( "fmt" "os" "strconv" + "sync" + "sync/atomic" sapb "github.com/letsencrypt/boulder/sa/proto" + "github.com/letsencrypt/boulder/unpause" "golang.org/x/exp/maps" ) // subcommandUnpauseAccount encapsulates the "admin unpause-account" command. type subcommandUnpauseAccount struct { - batchFile string - regID int64 + accountID int64 + batchFile string + parallelism uint } var _ subcommand = (*subcommandUnpauseAccount)(nil) @@ -26,8 +30,9 @@ func (u *subcommandUnpauseAccount) Desc() string { } func (u *subcommandUnpauseAccount) Flags(flag *flag.FlagSet) { + flag.Int64Var(&u.accountID, "account", 0, "A single account ID to unpause") flag.StringVar(&u.batchFile, "batch-file", "", "Path to a file containing multiple account IDs where each is separated by a newline") - flag.Int64Var(&u.regID, "account", 0, "A single account ID to unpause") + flag.UintVar(&u.parallelism, "parallelism", 10, "The maximum number of concurrent unpause requests to send to the SA (default: 10)") } func (u *subcommandUnpauseAccount) Run(ctx context.Context, a *admin) error { @@ -35,7 +40,7 @@ func (u *subcommandUnpauseAccount) Run(ctx context.Context, a *admin) error { // to a non-default value. We use this to ensure that exactly one input // selection flag was given on the command line. setInputs := map[string]bool{ - "-account": u.regID != 0, + "-account": u.accountID != 0, "-batch-file": u.batchFile != "", } maps.DeleteFunc(setInputs, func(_ string, v bool) bool { return !v }) @@ -49,7 +54,7 @@ func (u *subcommandUnpauseAccount) Run(ctx context.Context, a *admin) error { var err error switch maps.Keys(setInputs)[0] { case "-account": - regIDs = []int64{u.regID} + regIDs = []int64{u.accountID} case "-batch-file": regIDs, err = a.readUnpauseAccountFile(u.batchFile) default: @@ -59,7 +64,7 @@ func (u *subcommandUnpauseAccount) Run(ctx context.Context, a *admin) error { return fmt.Errorf("collecting serials to revoke: %w", err) } - _, err = a.unpauseAccounts(ctx, regIDs) + _, err = a.unpauseAccounts(ctx, regIDs, u.parallelism) if err != nil { return err } @@ -67,24 +72,62 @@ func (u *subcommandUnpauseAccount) Run(ctx context.Context, a *admin) error { return nil } -// unpauseAccount allows administratively unpausing all identifiers for an -// account. Returns a slice of int64 which is counter of unpaused accounts or an -// error. -func (a *admin) unpauseAccounts(ctx context.Context, regIDs []int64) ([]int64, error) { - var count []int64 - if len(regIDs) <= 0 { - return count, errors.New("no regIDs sent for unpausing") +type unpauseCount struct { + accountID int64 + count int64 +} + +// unpauseAccount concurrently unpauses all identifiers for each account using +// up to `parallelism` workers. It returns a count of the number of identifiers +// unpaused for each account and any accumulated errors. +func (a *admin) unpauseAccounts(ctx context.Context, accountIDs []int64, parallelism uint) ([]unpauseCount, error) { + if len(accountIDs) <= 0 { + return nil, errors.New("no account IDs provided for unpausing") } - for _, regID := range regIDs { - response, err := a.sac.UnpauseAccount(ctx, &sapb.RegistrationID{Id: regID}) - if err != nil { - return count, err - } - count = append(count, response.Count) + countChan := make(chan unpauseCount, len(accountIDs)) + work := make(chan int64, parallelism) + + var wg sync.WaitGroup + var errCount atomic.Uint64 + for i := uint(0); i < parallelism; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for accountID := range work { + response, err := a.sac.UnpauseAccount(ctx, &sapb.RegistrationID{Id: accountID}) + if err != nil { + errCount.Add(1) + a.log.Errf("error unpausing accountID %d: %v", accountID, err) + continue + } + if response.Count >= unpause.RequestLimit { + work <- accountID + } + countChan <- unpauseCount{accountID: accountID, count: response.Count} + } + }() + } + + for _, accountID := range accountIDs { + work <- accountID + } + close(work) + wg.Wait() + close(countChan) + + var unpauseCounts []unpauseCount + for count := range countChan { + // There could be multiple unpause requests for the same account ID if + // the account has more than `unpause.RequestLimit` identifiers. + unpauseCounts = append(unpauseCounts, count) + } + + if errCount.Load() > 0 { + return unpauseCounts, fmt.Errorf("encountered %d errors while unpausing; see logs above for details", errCount.Load()) } - return count, nil + return unpauseCounts, nil } // readUnpauseAccountFile parses the contents of a file containing one account diff --git a/cmd/admin/unpause_account_test.go b/cmd/admin/unpause_account_test.go index 4eed71f5f871..c81cccdc4d20 100644 --- a/cmd/admin/unpause_account_test.go +++ b/cmd/admin/unpause_account_test.go @@ -127,7 +127,7 @@ func TestUnpauseAccounts(t *testing.T) { } a := admin{sac: testCase.saImpl, log: log} - count, err := a.unpauseAccounts(context.Background(), testCase.regIDs) + count, err := a.unpauseAccounts(context.Background(), testCase.regIDs, 10) if testCase.expectErr { test.AssertError(t, err, "should have errored, but did not") } else {