Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

admin: PauseIdentifier batch by account and pause in parallel #7689

Merged
merged 2 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 61 additions & 20 deletions cmd/admin/pause_identifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@ import (
"io"
"os"
"strconv"
"sync"
"sync/atomic"

"github.com/letsencrypt/boulder/identifier"
sapb "github.com/letsencrypt/boulder/sa/proto"
)

// subcommandPauseIdentifier encapsulates the "admin pause-identifiers" command.
type subcommandPauseIdentifier struct {
batchFile string
batchFile string
parallelism uint
}

var _ subcommand = (*subcommandPauseIdentifier)(nil)
Expand All @@ -27,6 +30,7 @@ func (p *subcommandPauseIdentifier) Desc() string {

func (p *subcommandPauseIdentifier) Flags(flag *flag.FlagSet) {
flag.StringVar(&p.batchFile, "batch-file", "", "Path to a CSV file containing (account ID, identifier type, identifier value)")
flag.UintVar(&p.parallelism, "parallelism", 10, "The maximum number of concurrent pause requests to send to the SA (default: 10)")
}

func (p *subcommandPauseIdentifier) Run(ctx context.Context, a *admin) error {
Expand All @@ -39,39 +43,76 @@ func (p *subcommandPauseIdentifier) Run(ctx context.Context, a *admin) error {
return err
}

_, err = a.pauseIdentifiers(ctx, identifiers)
_, err = a.pauseIdentifiers(ctx, identifiers, p.parallelism)
if err != nil {
return err
}

return nil
}

// pauseIdentifiers allows administratively pausing a set of domain names for an
// account. It returns a slice of PauseIdentifiersResponse or an error.
func (a *admin) pauseIdentifiers(ctx context.Context, incoming []pauseCSVData) ([]*sapb.PauseIdentifiersResponse, error) {
if len(incoming) <= 0 {
// pauseIdentifiers concurrently pauses identifiers for each account using up to
// `parallelism` workers. It returns all pause responses and any accumulated
// errors.
func (a *admin) pauseIdentifiers(ctx context.Context, entries []pauseCSVData, parallelism uint) ([]*sapb.PauseIdentifiersResponse, error) {
if len(entries) <= 0 {
return nil, errors.New("cannot pause identifiers because no pauseData was sent")
}

accountToIdentifiers := make(map[int64][]*sapb.Identifier)
for _, entry := range entries {
accountToIdentifiers[entry.accountID] = append(accountToIdentifiers[entry.accountID], &sapb.Identifier{
Type: string(entry.identifierType),
Value: entry.identifierValue,
})
}

var errCount atomic.Uint64
respChan := make(chan *sapb.PauseIdentifiersResponse, len(accountToIdentifiers))
work := make(chan struct {
accountID int64
identifiers []*sapb.Identifier
}, parallelism)

var wg sync.WaitGroup
for i := uint(0); i < parallelism; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for data := range work {
response, err := a.sac.PauseIdentifiers(ctx, &sapb.PauseRequest{
RegistrationID: data.accountID,
Identifiers: data.identifiers,
})
if err != nil {
errCount.Add(1)
a.log.Errf("error pausing identifier(s) %q for account %d: %v", data.identifiers, data.accountID, err)
} else {
respChan <- response
}
}
}()
}

for accountID, identifiers := range accountToIdentifiers {
work <- struct {
accountID int64
identifiers []*sapb.Identifier
}{accountID, identifiers}
}
close(work)
wg.Wait()
close(respChan)

var responses []*sapb.PauseIdentifiersResponse
for _, data := range incoming {
req := sapb.PauseRequest{
RegistrationID: data.accountID,
Identifiers: []*sapb.Identifier{
{
Type: string(data.identifierType),
Value: data.identifierValue,
},
},
}
response, err := a.sac.PauseIdentifiers(ctx, &req)
if err != nil {
return nil, err
}
for response := range respChan {
responses = append(responses, response)
}

if errCount.Load() > 0 {
return responses, fmt.Errorf("encountered %d errors while pausing identifiers; see logs above for details", errCount.Load())
}

return responses, nil
}

Expand Down
16 changes: 10 additions & 6 deletions cmd/admin/pause_identifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,11 @@ func TestPauseIdentifiers(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
data []pauseCSVData
saImpl sapb.StorageAuthorityClient
expectErr bool
name string
data []pauseCSVData
saImpl sapb.StorageAuthorityClient
expectRespLen int
expectErr bool
}{
{
name: "no data",
Expand All @@ -125,6 +126,7 @@ func TestPauseIdentifiers(t *testing.T) {
identifierValue: "example.com",
},
},
expectRespLen: 1,
},
{
name: "valid single entry but broken SA",
Expand Down Expand Up @@ -167,6 +169,7 @@ func TestPauseIdentifiers(t *testing.T) {
identifierValue: "example.org",
},
},
expectRespLen: 3,
},
}

Expand All @@ -182,12 +185,13 @@ func TestPauseIdentifiers(t *testing.T) {
}
a := admin{sac: testCase.saImpl, log: log}

responses, err := a.pauseIdentifiers(context.Background(), testCase.data)
responses, err := a.pauseIdentifiers(context.Background(), testCase.data, 10)
if testCase.expectErr {
test.AssertError(t, err, "should have errored, but did not")
} else {
test.AssertNotError(t, err, "should not have errored")
test.AssertEquals(t, len(responses), len(testCase.data))
// Batching will consolidate identifiers under the same account.
test.AssertEquals(t, len(responses), testCase.expectRespLen)
}
})
}
Expand Down
Loading