diff --git a/cmd/admin/pause_identifier.go b/cmd/admin/pause_identifier.go index 81617e085a1..fd04a491a5e 100644 --- a/cmd/admin/pause_identifier.go +++ b/cmd/admin/pause_identifier.go @@ -9,6 +9,8 @@ import ( "io" "os" "strconv" + "sync" + "sync/atomic" "github.com/letsencrypt/boulder/identifier" sapb "github.com/letsencrypt/boulder/sa/proto" @@ -16,7 +18,8 @@ import ( // subcommandPauseIdentifier encapsulates the "admin pause-identifiers" command. type subcommandPauseIdentifier struct { - batchFile string + batchFile string + parallelism uint } var _ subcommand = (*subcommandPauseIdentifier)(nil) @@ -27,6 +30,7 @@ func (p *subcommandPauseIdentifier) Desc() string { func (p *subcommandPauseIdentifier) Flags(flag *flag.FlagSet) { flag.StringVar(&p.batchFile, "batch-file", "", "Path to a CSV file containing (account ID, identifier type, identifier value)") + flag.UintVar(&p.parallelism, "parallelism", 10, "The maximum number of concurrent pause requests to send to the SA (default: 10)") } func (p *subcommandPauseIdentifier) Run(ctx context.Context, a *admin) error { @@ -39,7 +43,7 @@ func (p *subcommandPauseIdentifier) Run(ctx context.Context, a *admin) error { return err } - _, err = a.pauseIdentifiers(ctx, identifiers) + _, err = a.pauseIdentifiers(ctx, identifiers, p.parallelism) if err != nil { return err } @@ -47,31 +51,68 @@ func (p *subcommandPauseIdentifier) Run(ctx context.Context, a *admin) error { return nil } -// pauseIdentifiers allows administratively pausing a set of domain names for an -// account. It returns a slice of PauseIdentifiersResponse or an error. -func (a *admin) pauseIdentifiers(ctx context.Context, incoming []pauseCSVData) ([]*sapb.PauseIdentifiersResponse, error) { - if len(incoming) <= 0 { +// pauseIdentifiers concurrently pauses identifiers for each account using up to +// `parallelism` workers. It returns all pause responses and any accumulated +// errors. +func (a *admin) pauseIdentifiers(ctx context.Context, entries []pauseCSVData, parallelism uint) ([]*sapb.PauseIdentifiersResponse, error) { + if len(entries) <= 0 { return nil, errors.New("cannot pause identifiers because no pauseData was sent") } + accountToIdentifiers := make(map[int64][]*sapb.Identifier) + for _, entry := range entries { + accountToIdentifiers[entry.accountID] = append(accountToIdentifiers[entry.accountID], &sapb.Identifier{ + Type: string(entry.identifierType), + Value: entry.identifierValue, + }) + } + + var errCount atomic.Uint64 + respChan := make(chan *sapb.PauseIdentifiersResponse, len(accountToIdentifiers)) + work := make(chan struct { + accountID int64 + identifiers []*sapb.Identifier + }, parallelism) + + var wg sync.WaitGroup + for i := uint(0); i < parallelism; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for data := range work { + response, err := a.sac.PauseIdentifiers(ctx, &sapb.PauseRequest{ + RegistrationID: data.accountID, + Identifiers: data.identifiers, + }) + if err != nil { + errCount.Add(1) + a.log.Errf("error pausing identifier(s) %q for account %d: %v", data.identifiers, data.accountID, err) + } else { + respChan <- response + } + } + }() + } + + for accountID, identifiers := range accountToIdentifiers { + work <- struct { + accountID int64 + identifiers []*sapb.Identifier + }{accountID, identifiers} + } + close(work) + wg.Wait() + close(respChan) + var responses []*sapb.PauseIdentifiersResponse - for _, data := range incoming { - req := sapb.PauseRequest{ - RegistrationID: data.accountID, - Identifiers: []*sapb.Identifier{ - { - Type: string(data.identifierType), - Value: data.identifierValue, - }, - }, - } - response, err := a.sac.PauseIdentifiers(ctx, &req) - if err != nil { - return nil, err - } + for response := range respChan { responses = append(responses, response) } + if errCount.Load() > 0 { + return responses, fmt.Errorf("encountered %d errors while pausing identifiers; see logs above for details", errCount.Load()) + } + return responses, nil } diff --git a/cmd/admin/pause_identifier_test.go b/cmd/admin/pause_identifier_test.go index 588003a0700..70dc050d183 100644 --- a/cmd/admin/pause_identifier_test.go +++ b/cmd/admin/pause_identifier_test.go @@ -106,10 +106,11 @@ func TestPauseIdentifiers(t *testing.T) { t.Parallel() testCases := []struct { - name string - data []pauseCSVData - saImpl sapb.StorageAuthorityClient - expectErr bool + name string + data []pauseCSVData + saImpl sapb.StorageAuthorityClient + expectRespLen int + expectErr bool }{ { name: "no data", @@ -125,6 +126,7 @@ func TestPauseIdentifiers(t *testing.T) { identifierValue: "example.com", }, }, + expectRespLen: 1, }, { name: "valid single entry but broken SA", @@ -167,6 +169,7 @@ func TestPauseIdentifiers(t *testing.T) { identifierValue: "example.org", }, }, + expectRespLen: 3, }, } @@ -182,12 +185,13 @@ func TestPauseIdentifiers(t *testing.T) { } a := admin{sac: testCase.saImpl, log: log} - responses, err := a.pauseIdentifiers(context.Background(), testCase.data) + responses, err := a.pauseIdentifiers(context.Background(), testCase.data, 10) if testCase.expectErr { test.AssertError(t, err, "should have errored, but did not") } else { test.AssertNotError(t, err, "should not have errored") - test.AssertEquals(t, len(responses), len(testCase.data)) + // Batching will consolidate identifiers under the same account. + test.AssertEquals(t, len(responses), testCase.expectRespLen) } }) }