diff --git a/main.go b/main.go index 50c9793..ad20aac 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,18 @@ import ( "github.com/google/go-github/v60/github" ) +type ec2RunInstancesAPI interface { + RunInstances(ctx context.Context, params *ec2.RunInstancesInput, optFns ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) +} + +var newEC2Client = func(cfg aws.Config) ec2RunInstancesAPI { + return ec2.NewFromConfig(cfg) +} + +var loadAWSConfig = func(ctx context.Context, optFns ...func(*config.LoadOptions) error) (aws.Config, error) { + return config.LoadDefaultConfig(ctx, optFns...) +} + func handler(request events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) { var githubEventHeader string for k, v := range request.MultiValueHeaders { @@ -32,7 +44,16 @@ func handler(request events.APIGatewayProxyRequest) (events.APIGatewayProxyRespo slog.Info("no github event header") return events.APIGatewayProxyResponse{StatusCode: 200}, nil } - event, err := github.ParseWebHook(githubEventHeader, []byte(request.Body)) + body := request.Body + if request.IsBase64Encoded { + decoded, err := base64.StdEncoding.DecodeString(request.Body) + if err != nil { + slog.Error("failed to decode body", "error", err.Error()) + return events.APIGatewayProxyResponse{StatusCode: 400}, nil + } + body = string(decoded) + } + event, err := github.ParseWebHook(githubEventHeader, []byte(body)) if err != nil { slog.Error("error parsing webhook", "error", err.Error()) return events.APIGatewayProxyResponse{StatusCode: 200}, nil @@ -43,11 +64,11 @@ func handler(request events.APIGatewayProxyRequest) (events.APIGatewayProxyRespo slog.Info("not a queued job event") return events.APIGatewayProxyResponse{StatusCode: 200}, nil } - cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion("us-east-2")) + cfg, err := loadAWSConfig(context.TODO(), config.WithRegion("us-east-2")) if err != nil { return events.APIGatewayProxyResponse{StatusCode: 500}, err } - svc := ec2.NewFromConfig(cfg) + svc := newEC2Client(cfg) tags := []types.Tag{ { Key: aws.String("GitHub Workflow Job Event ID"), diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..f4f8bb3 --- /dev/null +++ b/main_test.go @@ -0,0 +1,64 @@ +package main + +import ( + "context" + "encoding/base64" + "testing" + + "github.com/aws/aws-lambda-go/events" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" +) + +type mockEC2Client struct { + input *ec2.RunInstancesInput +} + +func (m *mockEC2Client) RunInstances(ctx context.Context, params *ec2.RunInstancesInput, optFns ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) { + m.input = params + return &ec2.RunInstancesOutput{ + Instances: []types.Instance{{InstanceId: aws.String("i-1234567890")}}, + }, nil +} + +func TestHandlerQueuedEvent(t *testing.T) { + origNew := newEC2Client + origLoad := loadAWSConfig + defer func() { + newEC2Client = origNew + loadAWSConfig = origLoad + }() + + mockSvc := &mockEC2Client{} + newEC2Client = func(cfg aws.Config) ec2RunInstancesAPI { return mockSvc } + loadAWSConfig = func(ctx context.Context, optFns ...func(*config.LoadOptions) error) (aws.Config, error) { + return aws.Config{}, nil + } + + eventJSON := `{"action":"queued","workflow_job":{"id":1,"labels":["ephemeral"]}}` + encoded := base64.StdEncoding.EncodeToString([]byte(eventJSON)) + + req := events.APIGatewayProxyRequest{ + Body: encoded, + IsBase64Encoded: true, + MultiValueHeaders: map[string][]string{ + "X-GitHub-Event": {"workflow_job"}, + }, + } + + resp, err := handler(req) + if err != nil { + t.Fatalf("handler returned error: %v", err) + } + if resp.StatusCode != 200 { + t.Fatalf("unexpected status %d", resp.StatusCode) + } + if mockSvc.input == nil { + t.Fatal("RunInstances not called") + } + if resp.Body != "i-1234567890" { + t.Fatalf("unexpected body %s", resp.Body) + } +}