diff --git a/go.mod b/go.mod index 3c57f5d..88a0100 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.23 require ( github.com/aws/aws-lambda-go v1.51.1 - github.com/aws/aws-sdk-go-v2 v1.41.0 + github.com/aws/aws-sdk-go-v2 v1.41.1 github.com/aws/aws-sdk-go-v2/config v1.32.6 github.com/aws/aws-sdk-go-v2/service/ec2 v1.279.0 github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.41.0 @@ -14,12 +14,13 @@ require ( require ( github.com/aws/aws-sdk-go-v2/credentials v1.19.6 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 // indirect github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 // indirect + github.com/aws/aws-sdk-go-v2/service/ssm v1.67.8 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.30.8 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 // indirect diff --git a/go.sum b/go.sum index 44e2647..68253a7 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ github.com/aws/aws-lambda-go v1.51.1 h1:FpqpCK2WOSoq6hJvO9PhN44GzZHWCN3e9DUQgK0B github.com/aws/aws-lambda-go v1.51.1/go.mod h1:dpMpZgvWx5vuQJfBt0zqBha60q7Dd7RfgJv23DymV8A= github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= github.com/aws/aws-sdk-go-v2 v1.41.0/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= +github.com/aws/aws-sdk-go-v2 v1.41.1 h1:ABlyEARCDLN034NhxlRUSZr4l71mh+T5KAeGh6cerhU= +github.com/aws/aws-sdk-go-v2 v1.41.1/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= github.com/aws/aws-sdk-go-v2/config v1.32.6 h1:hFLBGUKjmLAekvi1evLi5hVvFQtSo3GYwi+Bx4lpJf8= github.com/aws/aws-sdk-go-v2/config v1.32.6/go.mod h1:lcUL/gcd8WyjCrMnxez5OXkO3/rwcNmvfno62tnXNcI= github.com/aws/aws-sdk-go-v2/credentials v1.19.6 h1:F9vWao2TwjV2MyiyVS+duza0NIRtAslgLUM0vTA1ZaE= @@ -10,8 +12,12 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 h1:80+uETIWS1BqjnN9uJ0dBU github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16/go.mod h1:wOOsYuxYuB/7FlnVtzeBYRcjSRtQpAW0hCP7tIULMwo= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 h1:rgGwPzb82iBYSvHMHXc8h9mRoOUBZIGFgKb9qniaZZc= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16/go.mod h1:L/UxsGeKpGoIj6DxfhOWHWQ/kGKcd4I1VncE4++IyKA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17 h1:xOLELNKGp2vsiteLsvLPwxC+mYmO6OZ8PYgiuPJzF8U= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17/go.mod h1:5M5CI3D12dNOtH3/mk6minaRwI2/37ifCURZISxA/IQ= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 h1:1jtGzuV7c82xnqOVfx2F0xmJcOw5374L7N6juGW6x6U= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16/go.mod h1:M2E5OQf+XLe+SZGmmpaI2yy+J326aFf6/+54PoxSANc= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17 h1:WWLqlh79iO48yLkj1v3ISRNiv+3KdQoZ6JWyfcsyQik= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17/go.mod h1:EhG22vHRrvF8oXSTYStZhJc1aUgKtnJe+aOiFEV90cM= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= github.com/aws/aws-sdk-go-v2/service/ec2 v1.279.0 h1:o7eJKe6VYAnqERPlLAvDW5VKXV6eTKv1oxTpMoDP378= @@ -24,6 +30,8 @@ github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.41.0 h1:vL6rQXcGtFv9q/9eR github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.41.0/go.mod h1:QwEDLD+7EukuEUnbWtiNE8LhgvvmhjZoi4XAppYPtyc= github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 h1:HpI7aMmJ+mm1wkSHIA2t5EaFFv5EFYXePW30p1EIrbQ= github.com/aws/aws-sdk-go-v2/service/signin v1.0.4/go.mod h1:C5RdGMYGlfM0gYq/tifqgn4EbyX99V15P2V3R+VHbQU= +github.com/aws/aws-sdk-go-v2/service/ssm v1.67.8 h1:31Llf5VfrZ78YvYs7sWcS7L2m3waikzRc6q1nYenVS4= +github.com/aws/aws-sdk-go-v2/service/ssm v1.67.8/go.mod h1:/jgaDlU1UImoxTxhRNxXHvBAPqPZQ8oCjcPbbkR6kac= github.com/aws/aws-sdk-go-v2/service/sso v1.30.8 h1:aM/Q24rIlS3bRAhTyFurowU8A0SMyGDtEOY/l/s/1Uw= github.com/aws/aws-sdk-go-v2/service/sso v1.30.8/go.mod h1:+fWt2UHSb4kS7Pu8y+BMBvJF0EWx+4H0hzNwtDNRTrg= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 h1:AHDr0DaHIAo8c9t1emrzAlVDFp+iMMKnPdYy6XO4MCE= diff --git a/main.go b/main.go index d7b91ef..27b696f 100644 --- a/main.go +++ b/main.go @@ -5,15 +5,17 @@ import ( "context" _ "embed" "encoding/base64" + "encoding/json" "errors" "fmt" + "io" "log/slog" "net/http" "os" "slices" "strconv" "strings" - "text/template" + "time" "github.com/aws/aws-lambda-go/events" "github.com/aws/aws-lambda-go/lambda" @@ -22,13 +24,512 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go-v2/service/secretsmanager" + "github.com/aws/aws-sdk-go-v2/service/ssm" + ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" "github.com/google/go-github/v60/github" ) //go:embed user-data.sh var userData string -func handler(request events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) { +// LaunchConfig holds common EC2 launch configuration. +type LaunchConfig struct { + ImageID string + SubnetID string + SecurityGroups []string + KeyName string + InstanceProfileArn string +} + +// parseWarmPoolConfig parses the WARM_POOL_CONFIG environment variable. +// Returns a map of instance type to target pool size. +func parseWarmPoolConfig() map[string]int { + configStr := os.Getenv("WARM_POOL_CONFIG") + if configStr == "" || configStr == "{}" { + return map[string]int{} + } + + var poolConfig map[string]int + if err := json.Unmarshal([]byte(configStr), &poolConfig); err != nil { + slog.Error("failed to parse WARM_POOL_CONFIG", "error", err.Error(), "config", configStr) + return map[string]int{} + } + + return poolConfig +} + +// warmPoolFilters returns the common filters for querying warm pool instances. +func warmPoolFilters(instanceType types.InstanceType, states []string) []types.Filter { + return []types.Filter{ + {Name: aws.String("tag:WarmPool"), Values: []string{"true"}}, + {Name: aws.String("tag:WarmPoolStatus"), Values: []string{"available"}}, + {Name: aws.String("tag:WarmPoolInstanceType"), Values: []string{string(instanceType)}}, + {Name: aws.String("instance-state-name"), Values: states}, + } +} + +// findAvailableWarmInstance searches for a stopped warm pool instance of the requested type. +func findAvailableWarmInstance(ctx context.Context, svc *ec2.Client, instanceType types.InstanceType) (*string, error) { + output, err := svc.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ + Filters: warmPoolFilters(instanceType, []string{"stopped"}), + }) + if err != nil { + return nil, fmt.Errorf("failed to describe instances: %w", err) + } + + for _, reservation := range output.Reservations { + for _, instance := range reservation.Instances { + return instance.InstanceId, nil + } + } + + return nil, nil +} + +// startWarmInstance activates a stopped warm pool instance for a job. +// It sets the activation tag, updates user-data, changes shutdown behavior to TERMINATE, and starts the instance. +func startWarmInstance(ctx context.Context, svc *ec2.Client, instanceID string, jobEventID int64, finalUserData string) error { + // Update tags to mark as activated and in-use + _, err := svc.CreateTags(ctx, &ec2.CreateTagsInput{ + Resources: []string{instanceID}, + Tags: []types.Tag{ + {Key: aws.String("WarmPoolStatus"), Value: aws.String("in-use")}, + {Key: aws.String("WarmPoolActivated"), Value: aws.String("true")}, + {Key: aws.String("GitHub Workflow Job Event ID"), Value: aws.String(strconv.FormatInt(jobEventID, 10))}, + {Key: aws.String("Name"), Value: aws.String("GitHub Workflow Ephemeral Runner")}, + }, + }) + if err != nil { + return fmt.Errorf("failed to update tags: %w", err) + } + + // Change shutdown behavior to TERMINATE so instance terminates after job + _, err = svc.ModifyInstanceAttribute(ctx, &ec2.ModifyInstanceAttributeInput{ + InstanceId: aws.String(instanceID), + InstanceInitiatedShutdownBehavior: &types.AttributeValue{ + Value: aws.String("terminate"), + }, + }) + if err != nil { + return fmt.Errorf("failed to update shutdown behavior: %w", err) + } + + // Update user data with the full setup script wrapped in multipart format + // so it runs on every boot (including this activation) + // Note: BlobAttributeValue handles base64 encoding automatically, don't pre-encode + wrappedUserData := wrapInMultipart(finalUserData) + _, err = svc.ModifyInstanceAttribute(ctx, &ec2.ModifyInstanceAttributeInput{ + InstanceId: aws.String(instanceID), + UserData: &types.BlobAttributeValue{Value: []byte(wrappedUserData)}, + }) + if err != nil { + return fmt.Errorf("failed to update user data: %w", err) + } + + // Start the instance + startOutput, err := svc.StartInstances(ctx, &ec2.StartInstancesInput{ + InstanceIds: []string{instanceID}, + }) + if err != nil { + return fmt.Errorf("failed to start instance: %w", err) + } + + // Verify the instance was actually stopped before we started it + for _, change := range startOutput.StartingInstances { + if *change.InstanceId == instanceID { + if change.PreviousState.Name != types.InstanceStateNameStopped { + return fmt.Errorf("instance %s was not stopped (was %s)", instanceID, change.PreviousState.Name) + } + } + } + + return nil +} + +// countWarmPoolInstances counts available instances in the warm pool for a given type. +func countWarmPoolInstances(ctx context.Context, svc *ec2.Client, instanceType types.InstanceType) (int, error) { + output, err := svc.DescribeInstances(ctx, &ec2.DescribeInstancesInput{ + Filters: warmPoolFilters(instanceType, []string{"stopped", "stopping"}), + }) + if err != nil { + return 0, fmt.Errorf("failed to describe instances: %w", err) + } + + count := 0 + for _, reservation := range output.Reservations { + count += len(reservation.Instances) + } + + return count, nil +} + +// multipartTemplate is the MIME multipart format for user-data that runs on every boot. +const multipartTemplate = `Content-Type: multipart/mixed; boundary="//" +MIME-Version: 1.0 + +--// +Content-Type: text/cloud-config; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit +Content-Disposition: attachment; filename="cloud-config.txt" + +#cloud-config +cloud_final_modules: +- [scripts-user, always] + +--// +Content-Type: text/x-shellscript; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit +Content-Disposition: attachment; filename="userdata.txt" + +%s +--//-- +` + +// wrapInMultipart wraps a shell script in multipart MIME format that runs on every boot. +func wrapInMultipart(script string) string { + return fmt.Sprintf(multipartTemplate, script) +} + +// warmPoolInitUserData is the script that stops the instance on first boot to enter the warm pool. +var warmPoolInitUserData = wrapInMultipart(`#!/bin/bash +# Warm pool init: just stop the instance after first boot +# When activated, user-data will be replaced with the real setup script +echo "Warm pool instance initializing, stopping to enter pool..." +shutdown -h now`) + +// buildRunInstancesInput creates a base RunInstancesInput with common configuration. +func buildRunInstancesInput(instanceType types.InstanceType, launchConfig LaunchConfig, shutdownBehavior types.ShutdownBehavior, tags []types.Tag, userData string) *ec2.RunInstancesInput { + return &ec2.RunInstancesInput{ + MinCount: aws.Int32(1), + MaxCount: aws.Int32(1), + EbsOptimized: aws.Bool(true), + ImageId: aws.String(launchConfig.ImageID), + InstanceInitiatedShutdownBehavior: shutdownBehavior, + InstanceType: instanceType, + IamInstanceProfile: &types.IamInstanceProfileSpecification{ + Arn: aws.String(launchConfig.InstanceProfileArn), + }, + NetworkInterfaces: []types.InstanceNetworkInterfaceSpecification{ + { + AssociatePublicIpAddress: aws.Bool(true), + SubnetId: aws.String(launchConfig.SubnetID), + DeleteOnTermination: aws.Bool(true), + DeviceIndex: aws.Int32(0), + Groups: launchConfig.SecurityGroups, + }, + }, + KeyName: aws.String(launchConfig.KeyName), + Monitoring: &types.RunInstancesMonitoringEnabled{Enabled: aws.Bool(true)}, + TagSpecifications: []types.TagSpecification{ + {ResourceType: types.ResourceTypeInstance, Tags: tags}, + {ResourceType: types.ResourceTypeVolume, Tags: tags}, + }, + UserData: aws.String(base64.StdEncoding.EncodeToString([]byte(userData))), + } +} + +// launchInstance runs an EC2 instance and returns its ID. +func launchInstance(ctx context.Context, svc *ec2.Client, input *ec2.RunInstancesInput) (*string, error) { + output, err := svc.RunInstances(ctx, input) + if err != nil { + return nil, err + } + if len(output.Instances) == 0 { + return nil, errors.New("no instance created") + } + return output.Instances[0].InstanceId, nil +} + +// launchWarmPoolInstance launches a new instance destined for the warm pool. +// Instances will stop after first boot and terminate after being used for a job. +func launchWarmPoolInstance(ctx context.Context, svc *ec2.Client, instanceType types.InstanceType, launchConfig LaunchConfig) (*string, error) { + tags := []types.Tag{ + {Key: aws.String("WarmPool"), Value: aws.String("true")}, + {Key: aws.String("WarmPoolStatus"), Value: aws.String("available")}, + {Key: aws.String("WarmPoolActivated"), Value: aws.String("false")}, + {Key: aws.String("WarmPoolInstanceType"), Value: aws.String(string(instanceType))}, + {Key: aws.String("WarmPoolCreatedAt"), Value: aws.String(time.Now().UTC().Format(time.RFC3339))}, + {Key: aws.String("Name"), Value: aws.String(fmt.Sprintf("GitHub Runner Warm Pool - %s", instanceType))}, + } + + input := buildRunInstancesInput(instanceType, launchConfig, types.ShutdownBehaviorStop, tags, warmPoolInitUserData) + instanceID, err := launchInstance(ctx, svc, input) + if err != nil { + return nil, fmt.Errorf("failed to launch warm pool instance: %w", err) + } + return instanceID, nil +} + +// launchFreshInstance launches a new instance that terminates after use. +func launchFreshInstance(ctx context.Context, svc *ec2.Client, instanceType types.InstanceType, launchConfig LaunchConfig, finalUserData string, jobEventID int64) (*string, error) { + tags := []types.Tag{ + {Key: aws.String("GitHub Workflow Job Event ID"), Value: aws.String(strconv.FormatInt(jobEventID, 10))}, + {Key: aws.String("Name"), Value: aws.String("GitHub Workflow Ephemeral Runner")}, + } + + input := buildRunInstancesInput(instanceType, launchConfig, types.ShutdownBehaviorTerminate, tags, finalUserData) + instanceID, err := launchInstance(ctx, svc, input) + if err != nil { + return nil, fmt.Errorf("failed to launch instance: %w", err) + } + return instanceID, nil +} + +// tryAcquireWarmInstance attempts to acquire and start a warm pool instance. +// Returns the instance ID if successful, nil if no instance available or on failure. +func tryAcquireWarmInstance(ctx context.Context, svc *ec2.Client, instanceType types.InstanceType, jobEventID int64, finalUserData string) *string { + warmInstanceID, err := findAvailableWarmInstance(ctx, svc, instanceType) + if err != nil { + slog.Warn("failed to query warm pool", "error", err.Error()) + return nil + } + if warmInstanceID == nil { + slog.Info("no warm pool instance available", "instanceType", instanceType) + return nil + } + + slog.Info("found warm pool instance", "instanceID", *warmInstanceID) + + if err := startWarmInstance(ctx, svc, *warmInstanceID, jobEventID, finalUserData); err != nil { + slog.Error("failed to start warm instance", "instanceID", *warmInstanceID, "error", err.Error()) + // Mark instance for cleanup + _, _ = svc.CreateTags(ctx, &ec2.CreateTagsInput{ + Resources: []string{*warmInstanceID}, + Tags: []types.Tag{{Key: aws.String("WarmPoolStatus"), Value: aws.String("failed")}}, + }) + return nil + } + + slog.Info("started warm pool instance", "instanceID", *warmInstanceID) + return warmInstanceID +} + +// replenishWarmPool launches replacement instances if the pool is below target size. +func replenishWarmPool(ctx context.Context, svc *ec2.Client, instanceType types.InstanceType, launchConfig LaunchConfig, targetSize int) { + currentCount, err := countWarmPoolInstances(ctx, svc, instanceType) + if err != nil { + slog.Warn("failed to count warm pool", "error", err.Error()) + return + } + if currentCount >= targetSize { + return + } + + slog.Info("replenishing warm pool", "instanceType", instanceType, "current", currentCount, "target", targetSize) + + newID, err := launchWarmPoolInstance(ctx, svc, instanceType, launchConfig) + if err != nil { + slog.Warn("failed to launch warm pool replacement", "error", err.Error()) + return + } + slog.Info("launched warm pool replacement", "instanceID", *newID) +} + +// getLaunchConfig builds LaunchConfig from environment variables. +func getLaunchConfig() (LaunchConfig, error) { + subnetID := os.Getenv("SUBNET_ID") + if subnetID == "" { + return LaunchConfig{}, errors.New("SUBNET_ID env var not set") + } + + sgIDs := os.Getenv("SECURITY_GROUP_IDS") + if sgIDs == "" { + return LaunchConfig{}, errors.New("SECURITY_GROUP_IDS env var not set") + } + + keyName := os.Getenv("KEY_NAME") + if keyName == "" { + return LaunchConfig{}, errors.New("KEY_NAME env var not set") + } + + instanceProfileArn := os.Getenv("INSTANCE_PROFILE_ARN") + if instanceProfileArn == "" { + return LaunchConfig{}, errors.New("INSTANCE_PROFILE_ARN env var not set") + } + + imageID := os.Getenv("IMAGE_ID") + if imageID == "" { + return LaunchConfig{}, errors.New("IMAGE_ID env var not set") + } + + return LaunchConfig{ + ImageID: imageID, + SubnetID: subnetID, + SecurityGroups: strings.Split(sgIDs, ","), + KeyName: keyName, + InstanceProfileArn: instanceProfileArn, + }, nil +} + +// JITConfigRequest represents the request body for generating a JIT runner config. +type JITConfigRequest struct { + Name string `json:"name"` + RunnerGroupID int `json:"runner_group_id"` + Labels []string `json:"labels"` + WorkFolder string `json:"work_folder"` +} + +// JITConfigResponse represents the response from the JIT config API. +type JITConfigResponse struct { + Runner struct { + ID int `json:"id"` + Name string `json:"name"` + } `json:"runner"` + EncodedJITConfig string `json:"encoded_jit_config"` +} + +// generateJITConfig calls the GitHub API to generate a JIT runner configuration. +// This eliminates the need for config.sh on the runner, saving 15-30 seconds. +func generateJITConfig(pat, org, runnerName string, labels []string) (*JITConfigResponse, error) { + reqBody := JITConfigRequest{ + Name: runnerName, + RunnerGroupID: 1, // Default runner group + Labels: labels, + WorkFolder: "_work", + } + + jsonBody, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal JIT config request: %w", err) + } + + url := fmt.Sprintf("https://api.github.com/orgs/%s/actions/runners/generate-jitconfig", org) + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Accept", "application/vnd.github+json") + req.Header.Set("Authorization", "Bearer "+pat) + req.Header.Set("X-GitHub-Api-Version", "2022-11-28") + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to call JIT config API: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusCreated { + return nil, fmt.Errorf("JIT config API returned status %d: %s", resp.StatusCode, string(body)) + } + + var jitResp JITConfigResponse + if err := json.Unmarshal(body, &jitResp); err != nil { + return nil, fmt.Errorf("failed to parse JIT config response: %w", err) + } + + return &jitResp, nil +} + +// storeJITConfigInSSM stores the JIT config in SSM Parameter Store for the instance to retrieve. +func storeJITConfigInSSM(ctx context.Context, ssmClient *ssm.Client, instanceID, jitConfig string) error { + paramName := fmt.Sprintf("/github-runner/jit-config/%s", instanceID) + + _, err := ssmClient.PutParameter(ctx, &ssm.PutParameterInput{ + Name: aws.String(paramName), + Value: aws.String(jitConfig), + Type: ssmtypes.ParameterTypeSecureString, + Overwrite: aws.Bool(true), + }) + if err != nil { + return fmt.Errorf("failed to store JIT config in SSM: %w", err) + } + + return nil +} + +// handleMaintenance processes scheduled warm pool maintenance events. +func handleMaintenance() error { + slog.Info("warm pool maintenance triggered") + + poolConfig := parseWarmPoolConfig() + if len(poolConfig) == 0 { + slog.Info("warm pool not configured, skipping maintenance") + return nil + } + + ctx := context.TODO() + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion("us-east-2")) + if err != nil { + return fmt.Errorf("failed to load AWS config: %w", err) + } + + svc := ec2.NewFromConfig(cfg) + + launchConfig, err := getLaunchConfig() + if err != nil { + return fmt.Errorf("failed to get launch config: %w", err) + } + + // Check and replenish each configured instance type + for instanceTypeStr, targetSize := range poolConfig { + if targetSize <= 0 { + continue + } + + instanceType := types.InstanceType(instanceTypeStr) + currentCount, err := countWarmPoolInstances(ctx, svc, instanceType) + if err != nil { + slog.Warn("failed to count warm pool", "instanceType", instanceType, "error", err.Error()) + continue + } + + slog.Info("checking warm pool", "instanceType", instanceType, "current", currentCount, "target", targetSize) + + // Launch instances to reach target size + for currentCount < targetSize { + newID, err := launchWarmPoolInstance(ctx, svc, instanceType, launchConfig) + if err != nil { + slog.Error("failed to launch warm pool instance", "instanceType", instanceType, "error", err.Error()) + break + } + slog.Info("launched warm pool instance", "instanceType", instanceType, "instanceID", *newID) + currentCount++ + } + } + + slog.Info("warm pool maintenance complete") + return nil +} + +// MaintenanceEvent represents a scheduled maintenance event from CloudWatch. +type MaintenanceEvent struct { + Source string `json:"source"` +} + +func handler(ctx context.Context, rawEvent json.RawMessage) (interface{}, error) { + // Try to detect if this is a maintenance event + var maintenanceEvent MaintenanceEvent + if err := json.Unmarshal(rawEvent, &maintenanceEvent); err == nil { + if maintenanceEvent.Source == "warmPoolMaintenance" { + if err := handleMaintenance(); err != nil { + slog.Error("maintenance failed", "error", err.Error()) + return nil, err + } + return map[string]string{"status": "ok"}, nil + } + } + + // Otherwise, treat as API Gateway event + var request events.APIGatewayProxyRequest + if err := json.Unmarshal(rawEvent, &request); err != nil { + slog.Error("failed to parse API Gateway event", "error", err.Error()) + return events.APIGatewayProxyResponse{StatusCode: http.StatusBadRequest}, err + } + + return handleWebhook(request) +} + +func handleWebhook(request events.APIGatewayProxyRequest) (events.APIGatewayProxyResponse, error) { var githubEventHeader string for k, v := range request.MultiValueHeaders { @@ -62,13 +563,16 @@ func handler(request events.APIGatewayProxyRequest) (events.APIGatewayProxyRespo return events.APIGatewayProxyResponse{StatusCode: http.StatusOK}, nil } - cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion("us-east-2")) + ctx := context.TODO() + + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion("us-east-2")) if err != nil { return events.APIGatewayProxyResponse{StatusCode: http.StatusInternalServerError}, err } svc := ec2.NewFromConfig(cfg) sm := secretsmanager.NewFromConfig(cfg) + ssmClient := ssm.NewFromConfig(cfg) secretName := os.Getenv("GITHUB_PAT_SECRET_NAME") if secretName == "" { @@ -77,7 +581,7 @@ func handler(request events.APIGatewayProxyRequest) (events.APIGatewayProxyRespo return events.APIGatewayProxyResponse{StatusCode: http.StatusInternalServerError}, errors.New("secret name missing") } - secretOut, err := sm.GetSecretValue(context.TODO(), &secretsmanager.GetSecretValueInput{SecretId: aws.String(secretName)}) + secretOut, err := sm.GetSecretValue(ctx, &secretsmanager.GetSecretValueInput{SecretId: aws.String(secretName)}) if err != nil { slog.Error("failed to get secret", "secret", secretName, "error", err.Error()) @@ -91,52 +595,10 @@ func handler(request events.APIGatewayProxyRequest) (events.APIGatewayProxyRespo extraLabels = "," + extraLabels } - subnetID := os.Getenv("SUBNET_ID") - if subnetID == "" { - slog.Error("SUBNET_ID env var not set") - - return events.APIGatewayProxyResponse{StatusCode: http.StatusInternalServerError}, errors.New("subnet id missing") - } - - sgIDs := os.Getenv("SECURITY_GROUP_IDS") - if sgIDs == "" { - slog.Error("SECURITY_GROUP_IDS env var not set") - - return events.APIGatewayProxyResponse{StatusCode: http.StatusInternalServerError}, errors.New("security groups missing") - } - - securityGroups := strings.Split(sgIDs, ",") - - keyName := os.Getenv("KEY_NAME") - if keyName == "" { - slog.Error("KEY_NAME env var not set") - - return events.APIGatewayProxyResponse{StatusCode: http.StatusInternalServerError}, errors.New("key name missing") - } - - instanceProfileArn := os.Getenv("INSTANCE_PROFILE_ARN") - if instanceProfileArn == "" { - slog.Error("INSTANCE_PROFILE_ARN env var not set") - - return events.APIGatewayProxyResponse{StatusCode: http.StatusInternalServerError}, errors.New("instance profile arn missing") - } - - imageID := os.Getenv("IMAGE_ID") - if imageID == "" { - slog.Error("IMAGE_ID env var not set") - - return events.APIGatewayProxyResponse{StatusCode: http.StatusInternalServerError}, errors.New("image id missing") - } - - tags := []types.Tag{ - { - Key: aws.String("GitHub Workflow Job Event ID"), - Value: aws.String(strconv.Itoa(int(event.GetWorkflowJob().GetID()))), - }, - { - Key: aws.String("Name"), - Value: aws.String("GitHub Workflow Ephemeral Runner"), - }, + launchConfig, err := getLaunchConfig() + if err != nil { + slog.Error("failed to get launch config", "error", err.Error()) + return events.APIGatewayProxyResponse{StatusCode: http.StatusInternalServerError}, err } ephemeral := slices.Contains(event.GetWorkflowJob().Labels, "ephemeral") @@ -159,79 +621,70 @@ func handler(request events.APIGatewayProxyRequest) (events.APIGatewayProxyRespo } } - slog.Info("creating instance", "instanceType", instanceType) + jobEventID := event.GetWorkflowJob().GetID() + runnerName := fmt.Sprintf("ephemeral-i-%d", jobEventID) - tpl, err := template.New("userdata").Parse(userData) - if err != nil { - return events.APIGatewayProxyResponse{StatusCode: http.StatusInternalServerError}, err + slog.Info("processing job", "instanceType", instanceType, "jobID", jobEventID, "runnerName", runnerName) + + // Build labels for the runner + labels := []string{string(instanceType), "ephemeral", "X64"} + if extraLabels != "" { + // extraLabels already has leading comma, split and add non-empty labels + for _, label := range strings.Split(extraLabels, ",") { + if label = strings.TrimSpace(label); label != "" { + labels = append(labels, label) + } + } } - var buf bytes.Buffer - if err := tpl.Execute(&buf, map[string]string{"GitHubPAT": pat, "ExtraLabels": extraLabels}); err != nil { + // Generate JIT config from GitHub API (eliminates need for config.sh on instance) + jitConfig, err := generateJITConfig(pat, "frgrisk", runnerName, labels) + if err != nil { + slog.Error("failed to generate JIT config", "error", err.Error()) return events.APIGatewayProxyResponse{StatusCode: http.StatusInternalServerError}, err } - finalUserData := buf.String() - - output, err := svc.RunInstances( - context.TODO(), - &ec2.RunInstancesInput{ - MinCount: aws.Int32(1), - MaxCount: aws.Int32(1), - EbsOptimized: aws.Bool(true), - ImageId: aws.String(imageID), - InstanceInitiatedShutdownBehavior: types.ShutdownBehaviorTerminate, - InstanceType: instanceType, - IamInstanceProfile: &types.IamInstanceProfileSpecification{ - Arn: aws.String(instanceProfileArn), - }, - NetworkInterfaces: []types.InstanceNetworkInterfaceSpecification{ - { - AssociatePublicIpAddress: aws.Bool(true), - SubnetId: aws.String(subnetID), - DeleteOnTermination: aws.Bool(true), - DeviceIndex: aws.Int32(0), - Groups: securityGroups, - }, - }, - KeyName: aws.String(keyName), - Monitoring: &types.RunInstancesMonitoringEnabled{Enabled: aws.Bool(true)}, - TagSpecifications: []types.TagSpecification{ - { - ResourceType: types.ResourceTypeInstance, - Tags: tags, - }, - { - ResourceType: types.ResourceTypeVolume, - Tags: tags, - }, - }, - // base64 encode user data - UserData: aws.String(base64.StdEncoding.EncodeToString([]byte(finalUserData))), - }, - ) - if err != nil { - slog.Error(err.Error()) + slog.Info("generated JIT config", "runnerID", jitConfig.Runner.ID, "runnerName", jitConfig.Runner.Name) + + // Get warm pool target size for this instance type + poolConfig := parseWarmPoolConfig() + targetSize := poolConfig[string(instanceType)] - return events.APIGatewayProxyResponse{ - Body: err.Error(), - StatusCode: http.StatusInternalServerError, - }, err + // Try warm pool first if configured + var instanceID *string + if targetSize > 0 { + slog.Info("checking warm pool", "instanceType", instanceType, "targetSize", targetSize) + instanceID = tryAcquireWarmInstance(ctx, svc, instanceType, jobEventID, userData) } - if len(output.Instances) == 0 { - slog.Error("no instance created") + // Launch fresh instance if warm pool not available or not configured + if instanceID == nil { + slog.Info("launching fresh instance", "instanceType", instanceType) + instanceID, err = launchFreshInstance(ctx, svc, instanceType, launchConfig, userData, jobEventID) + if err != nil { + slog.Error("failed to launch fresh instance", "error", err.Error()) + return events.APIGatewayProxyResponse{ + Body: err.Error(), + StatusCode: http.StatusInternalServerError, + }, err + } + slog.Info("instance launched", "instanceID", *instanceID) + } - return events.APIGatewayProxyResponse{ - Body: "no instance created", - StatusCode: http.StatusInternalServerError, - }, nil + // Store JIT config in SSM for the instance to retrieve + if err := storeJITConfigInSSM(ctx, ssmClient, *instanceID, jitConfig.EncodedJITConfig); err != nil { + slog.Error("failed to store JIT config in SSM", "error", err.Error()) + return events.APIGatewayProxyResponse{StatusCode: http.StatusInternalServerError}, err } + slog.Info("stored JIT config in SSM", "instanceID", *instanceID) - slog.Info("instance created", "instanceID", output.Instances[0].InstanceId) + // Replenish warm pool if needed + if targetSize > 0 { + replenishWarmPool(ctx, svc, instanceType, launchConfig, targetSize) + } return events.APIGatewayProxyResponse{ - Body: *output.Instances[0].InstanceId, + Body: *instanceID, StatusCode: http.StatusOK, }, nil diff --git a/template.yaml b/template.yaml index ee8f81b..2638591 100644 --- a/template.yaml +++ b/template.yaml @@ -21,11 +21,17 @@ Parameters: KeyName: Type: String Description: EC2 key pair name for the runner + WarmPoolConfig: + Type: String + Default: "{}" + Description: > + JSON map of instance type to pool size. Empty {} disables warm pool. + Example: {"c8a.2xlarge":2,"c8a.4xlarge":1} # More info about Globals: https://github.com/awslabs/serverless-application-model/blob/master/docs/globals.rst Globals: Function: - Timeout: 5 + Timeout: 30 MemorySize: 128 Resources: @@ -54,6 +60,15 @@ Resources: - logs:PutLogEvents - logs:DescribeLogStreams Resource: !Sub 'arn:${AWS::Partition}:logs:${AWS::Region}:${AWS::AccountId}:log-group:/aws/ec2/github-runner:*' + - PolicyName: SSMJITConfigPolicy + PolicyDocument: + Version: '2012-10-17' + Statement: + - Effect: Allow + Action: + - ssm:GetParameter + - ssm:DeleteParameter + Resource: !Sub 'arn:${AWS::Partition}:ssm:${AWS::Region}:${AWS::AccountId}:parameter/github-runner/jit-config/*' RunnerInstanceProfile: Type: AWS::IAM::InstanceProfile @@ -77,6 +92,13 @@ Resources: - method.request.header.X-GitHub-Event: Required: true Caching: false + WarmPoolMaintenance: + Type: Schedule + Properties: + Schedule: rate(5 minutes) + Description: Maintain warm pool of EC2 instances + Enabled: true + Input: '{"source": "warmPoolMaintenance"}' Environment: # More info about Env Vars: https://github.com/awslabs/serverless-application-model/blob/master/versions/2016-10-31.md#environment-object Variables: GITHUB_PAT_SECRET_NAME: !Ref GitHubPATSecretName @@ -86,6 +108,7 @@ Resources: SECURITY_GROUP_IDS: !Ref SecurityGroupIds KEY_NAME: !Ref KeyName INSTANCE_PROFILE_ARN: !GetAtt RunnerInstanceProfile.Arn + WARM_POOL_CONFIG: !Ref WarmPoolConfig Policies: - Statement: - Sid: RunInstances @@ -96,11 +119,25 @@ Resources: - ec2:CreateTags - ec2:RunInstances Resource: "*" + - Sid: WarmPoolManagement + Effect: Allow + Action: + - ec2:DescribeInstances + - ec2:StartInstances + - ec2:StopInstances + - ec2:ModifyInstanceAttribute + Resource: "*" - Sid: GetGitHubPAT Effect: Allow Action: - secretsmanager:GetSecretValue Resource: !Sub arn:${AWS::Partition}:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${GitHubPATSecretName}* + - Sid: SSMJITConfig + Effect: Allow + Action: + - ssm:PutParameter + - ssm:DeleteParameter + Resource: !Sub arn:${AWS::Partition}:ssm:${AWS::Region}:${AWS::AccountId}:parameter/github-runner/jit-config/* Metadata: BuildMethod: makefile diff --git a/user-data.sh b/user-data.sh index 7225372..aec56d5 100644 --- a/user-data.sh +++ b/user-data.sh @@ -57,18 +57,28 @@ sed -i 's/ap-southeast-3/us-east-2/g' /etc/apt/sources.list # Add ubuntu user to docker group usermod -aG docker ubuntu -# Setup runner directory -cd /opt -mkdir -p actions-runner -chown -R ubuntu:ubuntu actions-runner -cd actions-runner - -# Extract runner -log_to_cloudwatch "INFO" "Extracting GitHub runner" -if ! sudo -u ubuntu tar xzf ../runner-cache/actions-runner-linux-* -C .; then - log_to_cloudwatch "ERROR" "Failed to extract runner archive" - shutdown now - exit 1 +# Use pre-extracted runner if available, otherwise extract from cache +if [ -d "/opt/actions-runner" ] && [ -f "/opt/actions-runner/run.sh" ]; then + log_to_cloudwatch "INFO" "Using pre-extracted GitHub runner" + cd /opt/actions-runner +else + log_to_cloudwatch "INFO" "Pre-extracted runner not found, extracting from cache" + + # Find runner archive in cache + RUNNER_ARCHIVE=$(ls /opt/runner-cache/actions-runner-linux-*.tar.gz 2>/dev/null | head -1) + + if [ -z "$RUNNER_ARCHIVE" ]; then + log_to_cloudwatch "ERROR" "No runner archive found in /opt/runner-cache" + shutdown now + exit 1 + fi + + # Create directory and extract + mkdir -p /opt/actions-runner + cd /opt/actions-runner + tar xzf "$RUNNER_ARCHIVE" + chown -R ubuntu:ubuntu /opt/actions-runner + log_to_cloudwatch "INFO" "Extracted runner from $RUNNER_ARCHIVE" fi # Get instance type (we already have instance ID from earlier) @@ -76,88 +86,36 @@ INSTANCE_TYPE=$(curl -s -H "X-aws-ec2-metadata-token: $TOKEN" http://169.254.169 log_to_cloudwatch "INFO" "Instance: ${INSTANCE_ID}, Type: ${INSTANCE_TYPE}" -# Function to get GitHub registration token with retry -get_github_token() { - local max_attempts=5 - local attempt=1 - local delay=5 - - while [ $attempt -le $max_attempts ]; do - log_to_cloudwatch "INFO" "Attempting to get GitHub registration token (attempt ${attempt}/${max_attempts})" - - GITHUB_TOKEN=$(curl -s -L \ - -X POST \ - -H "Accept: application/vnd.github+json" \ - -H "Authorization: Bearer {{.GitHubPAT}}" \ - -H "X-GitHub-Api-Version: 2022-11-28" \ - https://api.github.com/orgs/frgrisk/actions/runners/registration-token | jq -r .token) - - if [ -n "$GITHUB_TOKEN" ] && [ "$GITHUB_TOKEN" != "null" ]; then - log_to_cloudwatch "INFO" "Successfully obtained GitHub registration token" - return 0 - fi - - log_to_cloudwatch "WARN" "Failed to get GitHub token, retrying in ${delay} seconds..." - sleep $delay - delay=$((delay * 2)) - attempt=$((attempt + 1)) - done - - log_to_cloudwatch "ERROR" "Failed to get GitHub registration token after ${max_attempts} attempts" - return 1 -} +# JIT config is stored in SSM Parameter Store by Lambda (avoids cloud-init caching issues) +# Parameter name is /github-runner/jit-config/{instance-id} +SSM_PARAM_NAME="/github-runner/jit-config/${INSTANCE_ID}" -# Get GitHub registration token -if ! get_github_token; then - log_to_cloudwatch "ERROR" "Unable to proceed without registration token" - shutdown now - exit 1 -fi +log_to_cloudwatch "INFO" "Fetching JIT config from SSM: ${SSM_PARAM_NAME}" -# Configure runner with retry -log_to_cloudwatch "INFO" "Configuring GitHub runner" -max_config_attempts=3 -config_attempt=1 - -while [ $config_attempt -le $max_config_attempts ]; do - if sudo -u ubuntu ./config.sh \ - --url https://github.com/frgrisk \ - --token "$GITHUB_TOKEN" \ - --disableupdate \ - --ephemeral \ - --labels "${INSTANCE_TYPE},ephemeral,X64{{.ExtraLabels}}" \ - --unattended \ - --name "ephemeral-${INSTANCE_ID}" \ - --work _work; then - - log_to_cloudwatch "INFO" "Runner configured successfully" - break - else - log_to_cloudwatch "WARN" "Runner configuration failed (attempt ${config_attempt}/${max_config_attempts})" - config_attempt=$((config_attempt + 1)) - if [ $config_attempt -le $max_config_attempts ]; then - sleep 10 - fi - fi -done +JIT_CONFIG=$(aws ssm get-parameter --name "${SSM_PARAM_NAME}" --with-decryption --query 'Parameter.Value' --output text 2>/dev/null || echo "") -if [ $config_attempt -gt $max_config_attempts ]; then - log_to_cloudwatch "ERROR" "Failed to configure runner after ${max_config_attempts} attempts" +if [ -z "$JIT_CONFIG" ]; then + log_to_cloudwatch "ERROR" "JIT config not found in SSM" shutdown now exit 1 fi +# Delete the parameter after reading (one-time use) +aws ssm delete-parameter --name "${SSM_PARAM_NAME}" 2>/dev/null || true + +log_to_cloudwatch "INFO" "JIT config retrieved from SSM, skipping config.sh" + END_TIME=$(date +%s) EXECUTION_TIME=$((END_TIME - START_TIME)) log_to_cloudwatch "INFO" "Setup completed in ${EXECUTION_TIME} seconds" -# Start the runner and wait for it to complete -log_to_cloudwatch "INFO" "Starting GitHub runner" +# Start the runner with JIT config (skips registration entirely) +log_to_cloudwatch "INFO" "Starting GitHub runner with JIT config" # Create a temporary file to capture runner output RUNNER_LOG=$(mktemp /tmp/runner-output.XXXXXX) -if sudo -u ubuntu ./run.sh 2>&1 | tee "${RUNNER_LOG}"; then +if sudo -u ubuntu ./run.sh --jitconfig "$JIT_CONFIG" 2>&1 | tee "${RUNNER_LOG}"; then log_to_cloudwatch "INFO" "Runner completed successfully" else EXIT_CODE=$?