Skip to content

Commit

Permalink
admin: PauseIdentifier batch by account and pause in parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
beautifulentropy committed Aug 28, 2024
1 parent da7865c commit 9384c09
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 27 deletions.
85 changes: 64 additions & 21 deletions cmd/admin/pause_identifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@ import (
"io"
"os"
"strconv"
"sync"

"github.com/letsencrypt/boulder/identifier"
sapb "github.com/letsencrypt/boulder/sa/proto"
"github.com/letsencrypt/boulder/semaphore"
)

// subcommandPauseIdentifier encapsulates the "admin pause-identifiers" command.
type subcommandPauseIdentifier struct {
batchFile string
batchFile string
maxInFlight int64
}

var _ subcommand = (*subcommandPauseIdentifier)(nil)
Expand All @@ -27,6 +30,7 @@ func (p *subcommandPauseIdentifier) Desc() string {

func (p *subcommandPauseIdentifier) Flags(flag *flag.FlagSet) {
flag.StringVar(&p.batchFile, "batch-file", "", "Path to a CSV file containing (account ID, identifier type, identifier value)")
flag.Int64Var(&p.maxInFlight, "max-in-flight", 10, "The maximum number of concurrent pause requests to send to the SA")
}

func (p *subcommandPauseIdentifier) Run(ctx context.Context, a *admin) error {
Expand All @@ -39,39 +43,78 @@ func (p *subcommandPauseIdentifier) Run(ctx context.Context, a *admin) error {
return err
}

_, err = a.pauseIdentifiers(ctx, identifiers)
_, err = a.pauseIdentifiers(ctx, identifiers, p.maxInFlight)
if err != nil {
return err
}

return nil
}

// pauseIdentifiers allows administratively pausing a set of domain names for an
// account. It returns a slice of PauseIdentifiersResponse or an error.
func (a *admin) pauseIdentifiers(ctx context.Context, incoming []pauseCSVData) ([]*sapb.PauseIdentifiersResponse, error) {
if len(incoming) <= 0 {
// pauseIdentifiers pauses each account, identifier pair in the provided slice
// of pauseCSVData entries. It will pause up to maxInFlight identifiers at a
// time. If any errors occur while pausing, they will be gathered and returned
// as a single error.
func (a *admin) pauseIdentifiers(ctx context.Context, entries []pauseCSVData, maxInFlight int64) ([]*sapb.PauseIdentifiersResponse, error) {
if len(entries) <= 0 {
return nil, errors.New("cannot pause identifiers because no pauseData was sent")
}

var responses []*sapb.PauseIdentifiersResponse
for _, data := range incoming {
req := sapb.PauseRequest{
RegistrationID: data.accountID,
Identifiers: []*sapb.Identifier{
{
Type: string(data.identifierType),
Value: data.identifierValue,
},
},
}
response, err := a.sac.PauseIdentifiers(ctx, &req)
if err != nil {
return nil, err
}
accountToIdentifiers := make(map[int64][]*sapb.Identifier)
for _, entry := range entries {
accountToIdentifiers[entry.accountID] = append(accountToIdentifiers[entry.accountID], &sapb.Identifier{
Type: string(entry.identifierType),
Value: entry.identifierValue,
})
}

respChan := make(chan *sapb.PauseIdentifiersResponse, len(accountToIdentifiers))
errorsChan := make(chan error, len(accountToIdentifiers))
sem := semaphore.NewWeighted(maxInFlight, 0)

var wg sync.WaitGroup
for accountID, identifiers := range accountToIdentifiers {
wg.Add(1)
go func(accountID int64, identifiers []*sapb.Identifier) {
defer wg.Done()

err := sem.Acquire(ctx, 1)
if err != nil {
errorsChan <- fmt.Errorf("while acquiring semaphore to pause identifier(s) %q for account %d: %w", identifiers, accountID, err)
return
}
defer sem.Release(1)

response, err := a.sac.PauseIdentifiers(ctx, &sapb.PauseRequest{
RegistrationID: accountID,
Identifiers: identifiers,
})
if err != nil {
errorsChan <- fmt.Errorf("while pausing identifiers %q for account %d: %w", identifiers, accountID, err)
return
}
respChan <- response
}(accountID, identifiers)
}

wg.Wait()
close(respChan)
close(errorsChan)

responses := make([]*sapb.PauseIdentifiersResponse, 0)
for response := range respChan {
responses = append(responses, response)
}

var errors []error
for err := range errorsChan {
errors = append(errors, err)
}

if len(errors) > 0 {
return responses, fmt.Errorf("one or more errors occurred while pausing: %v", errors)
}

return responses, nil
}

Expand Down
16 changes: 10 additions & 6 deletions cmd/admin/pause_identifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,11 @@ func TestPauseIdentifiers(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
data []pauseCSVData
saImpl sapb.StorageAuthorityClient
expectErr bool
name string
data []pauseCSVData
saImpl sapb.StorageAuthorityClient
expectRespLen int
expectErr bool
}{
{
name: "no data",
Expand All @@ -125,6 +126,7 @@ func TestPauseIdentifiers(t *testing.T) {
identifierValue: "example.com",
},
},
expectRespLen: 1,
},
{
name: "valid single entry but broken SA",
Expand Down Expand Up @@ -167,6 +169,7 @@ func TestPauseIdentifiers(t *testing.T) {
identifierValue: "example.org",
},
},
expectRespLen: 3,
},
}

Expand All @@ -182,12 +185,13 @@ func TestPauseIdentifiers(t *testing.T) {
}
a := admin{sac: testCase.saImpl, log: log}

responses, err := a.pauseIdentifiers(context.Background(), testCase.data)
responses, err := a.pauseIdentifiers(context.Background(), testCase.data, 10)
if testCase.expectErr {
test.AssertError(t, err, "should have errored, but did not")
} else {
test.AssertNotError(t, err, "should not have errored")
test.AssertEquals(t, len(responses), len(testCase.data))
// Batching will consolidate identifiers under the same account.
test.AssertEquals(t, len(responses), testCase.expectRespLen)
}
})
}
Expand Down

0 comments on commit 9384c09

Please sign in to comment.