Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@ import (
"github.com/google/go-github/v60/github"
)

type ec2RunInstancesAPI interface {
RunInstances(ctx context.Context, params *ec2.RunInstancesInput, optFns ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error)
}

var newEC2Client = func(cfg aws.Config) ec2RunInstancesAPI {
return ec2.NewFromConfig(cfg)
}

var loadAWSConfig = func(ctx context.Context, optFns ...func(*config.LoadOptions) error) (aws.Config, error) {
return config.LoadDefaultConfig(ctx, optFns...)
}

func handler(request events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) {
var githubEventHeader string
for k, v := range request.MultiValueHeaders {
Expand All @@ -32,7 +44,16 @@ func handler(request events.APIGatewayProxyRequest) (events.APIGatewayProxyRespo
slog.Info("no github event header")
return events.APIGatewayProxyResponse{StatusCode: 200}, nil
}
event, err := github.ParseWebHook(githubEventHeader, []byte(request.Body))
body := request.Body
if request.IsBase64Encoded {
decoded, err := base64.StdEncoding.DecodeString(request.Body)
if err != nil {
slog.Error("failed to decode body", "error", err.Error())
return events.APIGatewayProxyResponse{StatusCode: 400}, nil
}
body = string(decoded)
}
event, err := github.ParseWebHook(githubEventHeader, []byte(body))
if err != nil {
slog.Error("error parsing webhook", "error", err.Error())
return events.APIGatewayProxyResponse{StatusCode: 200}, nil
Expand All @@ -43,11 +64,11 @@ func handler(request events.APIGatewayProxyRequest) (events.APIGatewayProxyRespo
slog.Info("not a queued job event")
return events.APIGatewayProxyResponse{StatusCode: 200}, nil
}
cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion("us-east-2"))
cfg, err := loadAWSConfig(context.TODO(), config.WithRegion("us-east-2"))
if err != nil {
return events.APIGatewayProxyResponse{StatusCode: 500}, err
}
svc := ec2.NewFromConfig(cfg)
svc := newEC2Client(cfg)
tags := []types.Tag{
{
Key: aws.String("GitHub Workflow Job Event ID"),
Expand Down
64 changes: 64 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package main

import (
"context"
"encoding/base64"
"testing"

"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
)

type mockEC2Client struct {
input *ec2.RunInstancesInput
}

func (m *mockEC2Client) RunInstances(ctx context.Context, params *ec2.RunInstancesInput, optFns ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error) {
m.input = params
return &ec2.RunInstancesOutput{
Instances: []types.Instance{{InstanceId: aws.String("i-1234567890")}},
}, nil
}

func TestHandlerQueuedEvent(t *testing.T) {
origNew := newEC2Client
origLoad := loadAWSConfig
defer func() {
newEC2Client = origNew
loadAWSConfig = origLoad
}()

mockSvc := &mockEC2Client{}
newEC2Client = func(cfg aws.Config) ec2RunInstancesAPI { return mockSvc }
loadAWSConfig = func(ctx context.Context, optFns ...func(*config.LoadOptions) error) (aws.Config, error) {
return aws.Config{}, nil
}

eventJSON := `{"action":"queued","workflow_job":{"id":1,"labels":["ephemeral"]}}`
encoded := base64.StdEncoding.EncodeToString([]byte(eventJSON))

req := events.APIGatewayProxyRequest{
Body: encoded,
IsBase64Encoded: true,
MultiValueHeaders: map[string][]string{
"X-GitHub-Event": {"workflow_job"},
},
}

resp, err := handler(req)
if err != nil {
t.Fatalf("handler returned error: %v", err)
}
if resp.StatusCode != 200 {
t.Fatalf("unexpected status %d", resp.StatusCode)
}
if mockSvc.input == nil {
t.Fatal("RunInstances not called")
}
if resp.Body != "i-1234567890" {
t.Fatalf("unexpected body %s", resp.Body)
}
}