diff --git a/cmd/admin/pause_identifier.go b/cmd/admin/pause_identifier.go index 81617e085a1..7308c6a0e4d 100644 --- a/cmd/admin/pause_identifier.go +++ b/cmd/admin/pause_identifier.go @@ -9,14 +9,17 @@ import ( "io" "os" "strconv" + "sync" "github.com/letsencrypt/boulder/identifier" sapb "github.com/letsencrypt/boulder/sa/proto" + "github.com/letsencrypt/boulder/semaphore" ) // subcommandPauseIdentifier encapsulates the "admin pause-identifiers" command. type subcommandPauseIdentifier struct { - batchFile string + batchFile string + maxInFlight int64 } var _ subcommand = (*subcommandPauseIdentifier)(nil) @@ -27,6 +30,7 @@ func (p *subcommandPauseIdentifier) Desc() string { func (p *subcommandPauseIdentifier) Flags(flag *flag.FlagSet) { flag.StringVar(&p.batchFile, "batch-file", "", "Path to a CSV file containing (account ID, identifier type, identifier value)") + flag.Int64Var(&p.maxInFlight, "max-in-flight", 10, "The maximum number of concurrent pause requests to send to the SA") } func (p *subcommandPauseIdentifier) Run(ctx context.Context, a *admin) error { @@ -39,7 +43,7 @@ func (p *subcommandPauseIdentifier) Run(ctx context.Context, a *admin) error { return err } - _, err = a.pauseIdentifiers(ctx, identifiers) + _, err = a.pauseIdentifiers(ctx, identifiers, p.maxInFlight) if err != nil { return err } @@ -47,31 +51,70 @@ func (p *subcommandPauseIdentifier) Run(ctx context.Context, a *admin) error { return nil } -// pauseIdentifiers allows administratively pausing a set of domain names for an -// account. It returns a slice of PauseIdentifiersResponse or an error. -func (a *admin) pauseIdentifiers(ctx context.Context, incoming []pauseCSVData) ([]*sapb.PauseIdentifiersResponse, error) { - if len(incoming) <= 0 { +// pauseIdentifiers pauses each account, identifier pair in the provided slice +// of pauseCSVData entries. It will pause up to maxInFlight identifiers at a +// time. If any errors occur while pausing, they will be gathered and returned +// as a single error. +func (a *admin) pauseIdentifiers(ctx context.Context, entries []pauseCSVData, maxInFlight int64) ([]*sapb.PauseIdentifiersResponse, error) { + if len(entries) <= 0 { return nil, errors.New("cannot pause identifiers because no pauseData was sent") } - var responses []*sapb.PauseIdentifiersResponse - for _, data := range incoming { - req := sapb.PauseRequest{ - RegistrationID: data.accountID, - Identifiers: []*sapb.Identifier{ - { - Type: string(data.identifierType), - Value: data.identifierValue, - }, - }, - } - response, err := a.sac.PauseIdentifiers(ctx, &req) - if err != nil { - return nil, err - } + accountToIdentifiers := make(map[int64][]*sapb.Identifier) + for _, entry := range entries { + accountToIdentifiers[entry.accountID] = append(accountToIdentifiers[entry.accountID], &sapb.Identifier{ + Type: string(entry.identifierType), + Value: entry.identifierValue, + }) + } + + respChan := make(chan *sapb.PauseIdentifiersResponse, len(accountToIdentifiers)) + errorsChan := make(chan error, len(accountToIdentifiers)) + sem := semaphore.NewWeighted(maxInFlight, 0) + + var wg sync.WaitGroup + for accountID, identifiers := range accountToIdentifiers { + wg.Add(1) + go func(accountID int64, identifiers []*sapb.Identifier) { + defer wg.Done() + + err := sem.Acquire(ctx, 1) + if err != nil { + errorsChan <- fmt.Errorf("while acquiring semaphore to pause identifier(s) %q for account %d: %w", identifiers, accountID, err) + return + } + defer sem.Release(1) + + response, err := a.sac.PauseIdentifiers(ctx, &sapb.PauseRequest{ + RegistrationID: accountID, + Identifiers: identifiers, + }) + if err != nil { + errorsChan <- fmt.Errorf("while pausing identifiers %q for account %d: %w", identifiers, accountID, err) + return + } + respChan <- response + }(accountID, identifiers) + } + + wg.Wait() + close(respChan) + close(errorsChan) + + responses := make([]*sapb.PauseIdentifiersResponse, 0) + for response := range respChan { responses = append(responses, response) } + var errors []error + for err := range errorsChan { + errors = append(errors, err) + } + + if len(errors) > 0 { + return responses, fmt.Errorf("one or more errors occurred while pausing: %v", errors) + } + return responses, nil } diff --git a/cmd/admin/pause_identifier_test.go b/cmd/admin/pause_identifier_test.go index 588003a0700..70dc050d183 100644 --- a/cmd/admin/pause_identifier_test.go +++ b/cmd/admin/pause_identifier_test.go @@ -106,10 +106,11 @@ func TestPauseIdentifiers(t *testing.T) { t.Parallel() testCases := []struct { - name string - data []pauseCSVData - saImpl sapb.StorageAuthorityClient - expectErr bool + name string + data []pauseCSVData + saImpl sapb.StorageAuthorityClient + expectRespLen int + expectErr bool }{ { name: "no data", @@ -125,6 +126,7 @@ func TestPauseIdentifiers(t *testing.T) { identifierValue: "example.com", }, }, + expectRespLen: 1, }, { name: "valid single entry but broken SA", @@ -167,6 +169,7 @@ func TestPauseIdentifiers(t *testing.T) { identifierValue: "example.org", }, }, + expectRespLen: 3, }, } @@ -182,12 +185,13 @@ func TestPauseIdentifiers(t *testing.T) { } a := admin{sac: testCase.saImpl, log: log} - responses, err := a.pauseIdentifiers(context.Background(), testCase.data) + responses, err := a.pauseIdentifiers(context.Background(), testCase.data, 10) if testCase.expectErr { test.AssertError(t, err, "should have errored, but did not") } else { test.AssertNotError(t, err, "should not have errored") - test.AssertEquals(t, len(responses), len(testCase.data)) + // Batching will consolidate identifiers under the same account. + test.AssertEquals(t, len(responses), testCase.expectRespLen) } }) }