diff --git a/cmd/admin/pause_identifier_test.go b/cmd/admin/pause_identifier_test.go index 70dc050d183..937cf179107 100644 --- a/cmd/admin/pause_identifier_test.go +++ b/cmd/admin/pause_identifier_test.go @@ -84,12 +84,9 @@ func TestReadingPauseCSV(t *testing.T) { // PauseIdentifiersResponse. It does not maintain state of repaused identifiers. type mockSAPaused struct { sapb.StorageAuthorityClient - reqs []*sapb.PauseRequest } func (msa *mockSAPaused) PauseIdentifiers(ctx context.Context, in *sapb.PauseRequest, _ ...grpc.CallOption) (*sapb.PauseIdentifiersResponse, error) { - msa.reqs = append(msa.reqs, in) - return &sapb.PauseIdentifiersResponse{Paused: int64(len(in.Identifiers))}, nil } diff --git a/cmd/admin/unpause_account.go b/cmd/admin/unpause_account.go index 3ec2306ba13..9a56dde05f7 100644 --- a/cmd/admin/unpause_account.go +++ b/cmd/admin/unpause_account.go @@ -7,16 +7,21 @@ import ( "flag" "fmt" "os" + "slices" "strconv" + "sync" + "sync/atomic" sapb "github.com/letsencrypt/boulder/sa/proto" + "github.com/letsencrypt/boulder/unpause" "golang.org/x/exp/maps" ) // subcommandUnpauseAccount encapsulates the "admin unpause-account" command. type subcommandUnpauseAccount struct { - batchFile string - regID int64 + accountID int64 + batchFile string + parallelism uint } var _ subcommand = (*subcommandUnpauseAccount)(nil) @@ -26,8 +31,9 @@ func (u *subcommandUnpauseAccount) Desc() string { } func (u *subcommandUnpauseAccount) Flags(flag *flag.FlagSet) { + flag.Int64Var(&u.accountID, "account", 0, "A single account ID to unpause") flag.StringVar(&u.batchFile, "batch-file", "", "Path to a file containing multiple account IDs where each is separated by a newline") - flag.Int64Var(&u.regID, "account", 0, "A single account ID to unpause") + flag.UintVar(&u.parallelism, "parallelism", 10, "The maximum number of concurrent unpause requests to send to the SA (default: 10)") } func (u *subcommandUnpauseAccount) Run(ctx context.Context, a *admin) error { @@ -35,7 +41,7 @@ func (u *subcommandUnpauseAccount) Run(ctx context.Context, a *admin) error { // to a non-default value. We use this to ensure that exactly one input // selection flag was given on the command line. setInputs := map[string]bool{ - "-account": u.regID != 0, + "-account": u.accountID != 0, "-batch-file": u.batchFile != "", } maps.DeleteFunc(setInputs, func(_ string, v bool) bool { return !v }) @@ -49,7 +55,7 @@ func (u *subcommandUnpauseAccount) Run(ctx context.Context, a *admin) error { var err error switch maps.Keys(setInputs)[0] { case "-account": - regIDs = []int64{u.regID} + regIDs = []int64{u.accountID} case "-batch-file": regIDs, err = a.readUnpauseAccountFile(u.batchFile) default: @@ -59,7 +65,7 @@ func (u *subcommandUnpauseAccount) Run(ctx context.Context, a *admin) error { return fmt.Errorf("collecting serials to revoke: %w", err) } - _, err = a.unpauseAccounts(ctx, regIDs) + _, err = a.unpauseAccounts(ctx, regIDs, u.parallelism) if err != nil { return err } @@ -67,24 +73,72 @@ func (u *subcommandUnpauseAccount) Run(ctx context.Context, a *admin) error { return nil } -// unpauseAccount allows administratively unpausing all identifiers for an -// account. Returns a slice of int64 which is counter of unpaused accounts or an -// error. -func (a *admin) unpauseAccounts(ctx context.Context, regIDs []int64) ([]int64, error) { - var count []int64 - if len(regIDs) <= 0 { - return count, errors.New("no regIDs sent for unpausing") +type unpauseCount struct { + accountID int64 + count int64 +} + +// unpauseAccount concurrently unpauses all identifiers for each account using +// up to `parallelism` workers. It returns a count of the number of identifiers +// unpaused for each account and any accumulated errors. +func (a *admin) unpauseAccounts(ctx context.Context, accountIDs []int64, parallelism uint) ([]unpauseCount, error) { + if len(accountIDs) <= 0 { + return nil, errors.New("no account IDs provided for unpausing") + } + slices.Sort(accountIDs) + accountIDs = slices.Compact(accountIDs) + + countChan := make(chan unpauseCount, len(accountIDs)) + work := make(chan int64) + + var wg sync.WaitGroup + var errCount atomic.Uint64 + for i := uint(0); i < parallelism; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for accountID := range work { + totalCount := int64(0) + for { + response, err := a.sac.UnpauseAccount(ctx, &sapb.RegistrationID{Id: accountID}) + if err != nil { + errCount.Add(1) + a.log.Errf("error unpausing accountID %d: %v", accountID, err) + break + } + totalCount += response.Count + if response.Count < unpause.RequestLimit { + // All identifiers have been unpaused. + break + } + } + countChan <- unpauseCount{accountID: accountID, count: totalCount} + } + }() } - for _, regID := range regIDs { - response, err := a.sac.UnpauseAccount(ctx, &sapb.RegistrationID{Id: regID}) - if err != nil { - return count, err + go func() { + for _, accountID := range accountIDs { + work <- accountID } - count = append(count, response.Count) + close(work) + }() + + go func() { + wg.Wait() + close(countChan) + }() + + var unpauseCounts []unpauseCount + for count := range countChan { + unpauseCounts = append(unpauseCounts, count) + } + + if errCount.Load() > 0 { + return unpauseCounts, fmt.Errorf("encountered %d errors while unpausing; see logs above for details", errCount.Load()) } - return count, nil + return unpauseCounts, nil } // readUnpauseAccountFile parses the contents of a file containing one account diff --git a/cmd/admin/unpause_account_test.go b/cmd/admin/unpause_account_test.go index 4eed71f5f87..f39b168fcbf 100644 --- a/cmd/admin/unpause_account_test.go +++ b/cmd/admin/unpause_account_test.go @@ -60,20 +60,12 @@ func TestReadingUnpauseAccountsFile(t *testing.T) { } } -// mockSAPaused is a mock that always succeeds. It records each PauseRequest it -// receives, and returns the number of identifiers as a -// PauseIdentifiersResponse. It does not maintain state of repaused identifiers. type mockSAUnpause struct { sapb.StorageAuthorityClient - regIDCounter map[int64]int64 } func (msa *mockSAUnpause) UnpauseAccount(ctx context.Context, in *sapb.RegistrationID, _ ...grpc.CallOption) (*sapb.Count, error) { - if _, ok := msa.regIDCounter[in.Id]; ok { - msa.regIDCounter[in.Id] += 1 - } - - return &sapb.Count{Count: msa.regIDCounter[in.Id]}, nil + return &sapb.Count{Count: 1}, nil } // mockSAUnpauseBroken is a mock that always returns an error. @@ -89,10 +81,11 @@ func TestUnpauseAccounts(t *testing.T) { t.Parallel() testCases := []struct { - name string - regIDs []int64 - saImpl sapb.StorageAuthorityClient - expectErr bool + name string + regIDs []int64 + saImpl sapb.StorageAuthorityClient + expectErr bool + expectCounts int }{ { name: "no data", @@ -100,8 +93,9 @@ func TestUnpauseAccounts(t *testing.T) { expectErr: true, }, { - name: "valid single entry", - regIDs: []int64{1}, + name: "valid single entry", + regIDs: []int64{1}, + expectCounts: 1, }, { name: "valid single entry but broken SA", @@ -110,8 +104,9 @@ func TestUnpauseAccounts(t *testing.T) { regIDs: []int64{1}, }, { - name: "valid multiple entries with duplicates", - regIDs: []int64{1, 1, 2, 3, 4}, + name: "valid multiple entries with duplicates", + regIDs: []int64{1, 1, 2, 3, 4}, + expectCounts: 4, }, } @@ -127,12 +122,12 @@ func TestUnpauseAccounts(t *testing.T) { } a := admin{sac: testCase.saImpl, log: log} - count, err := a.unpauseAccounts(context.Background(), testCase.regIDs) + counts, err := a.unpauseAccounts(context.Background(), testCase.regIDs, 10) if testCase.expectErr { test.AssertError(t, err, "should have errored, but did not") } else { test.AssertNotError(t, err, "should not have errored") - test.AssertEquals(t, len(count), len(testCase.regIDs)) + test.AssertEquals(t, testCase.expectCounts, len(counts)) } }) }