Skip to content

Commit

Permalink
Allow passing payloads to rollout route
Browse files Browse the repository at this point in the history
  • Loading branch information
joecorall committed Apr 18, 2024
1 parent 5425d12 commit afa9bbb
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 17 deletions.
69 changes: 65 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,26 @@ import (
"net/http"
"os"
"os/exec"
"reflect"

"github.com/golang-jwt/jwt/v5"
"github.com/google/shlex"
"github.com/lestrrat-go/jwx/jwk"
)

type RolloutPayload struct {
DockerImage string `json:"docker-image" env:"DOCKER_IMAGE"`
DockerTag string `json:"docker-tag" env:"DOCKER_TAG"`
GitRepo string `json:"git-repo" env:"GIT_REPO"`
GitBranch string `json:"git-branch" env:"GIT_BRANCH"`
Arg1 string `json:"rollout-arg1" env:"ROLLOUT_ARG1"`
Arg2 string `json:"rollout-arg2" env:"ROLLOUT_ARG2"`
Arg3 string `json:"rollout-arg3" env:"ROLLOUT_ARG3"`
}

func init() {
// call getArgs early to fail on a bad config
getArgs()
// call getRolloutCmdArgs early to fail on a bad config
getRolloutCmdArgs()
}

func main() {
Expand Down Expand Up @@ -88,11 +99,18 @@ func Rollout(w http.ResponseWriter, r *http.Request) {
return
}

err = setCustomArgs(r)
if err != nil {
slog.Error("Error setting custom logs", "err", err)
http.Error(w, "Script execution failed", http.StatusInternalServerError)
return
}

name := os.Getenv("ROLLOUT_CMD")
if name == "" {
name = "/bin/bash"
}
cmd := exec.Command(name, getArgs()...)
cmd := exec.Command(name, getRolloutCmdArgs()...)

var stdOut, stdErr bytes.Buffer
cmd.Stdout = &stdOut
Expand Down Expand Up @@ -166,7 +184,7 @@ func strInSlice(e string, s []string) bool {
return false
}

func getArgs() []string {
func getRolloutCmdArgs() []string {
args := os.Getenv("ROLLOUT_ARGS")
if args == "" {
args = "/rollout.sh"
Expand All @@ -179,3 +197,46 @@ func getArgs() []string {

return rolloutArgs
}

func setCustomArgs(r *http.Request) error {
if r.Method == "GET" {
return nil
}

var payload RolloutPayload
decoder := json.NewDecoder(r.Body)
err := decoder.Decode(&payload)
if err != nil {
return err
}

err = setEnvFromStruct(&payload)
if err != nil {
return err
}

return nil
}

func setEnvFromStruct(data interface{}) error {
v := reflect.ValueOf(data)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}

t := v.Type()
for i := 0; i < v.NumField(); i++ {
field := t.Field(i)
if envTag, ok := field.Tag.Lookup("env"); ok {
// For now all fields are strings
value := v.Field(i).String()
if value == "" {
continue
}
if err := os.Setenv(envTag, value); err != nil {
return fmt.Errorf("could not set environment variable %s: %v", envTag, err)
}
}
}
return nil
}
110 changes: 97 additions & 13 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log/slog"
"math/big"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"

Expand All @@ -23,6 +25,16 @@ var (
privateKey *rsa.PrivateKey
)

type Test struct {
name string
authHeader string
expectedStatus int
expectedBody string
cmdArgs string
method string
payload string
}

// createJWKS creates a JWKS JSON representation with a single RSA key.
func mockJWKS(pub *rsa.PublicKey, kid string) (string, error) {
jwks := struct {
Expand Down Expand Up @@ -130,14 +142,14 @@ func CreateSignedJWT(kid, aud, claim string, exp int64, privateKey *rsa.PrivateK
}

// Utility function to create a request with an Authorization header
func createRequest(authHeader string) *http.Request {
req, _ := http.NewRequest("GET", "/", nil)
func createRequest(authHeader, method string, body io.Reader) *http.Request {
req, _ := http.NewRequest(method, "/", body)
req.Header.Set("Authorization", authHeader)
return req
}

// TestRollout tests the Rollout function with various scenarios
func TestRollout(t *testing.T) {
func TestRolloutAuth(t *testing.T) {
testFile := "/tmp/rollout-test.txt"

// have our test rollout cmd just touch a file
Expand Down Expand Up @@ -196,14 +208,7 @@ func TestRollout(t *testing.T) {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}

tests := []struct {
name string
authHeader string
expectedStatus int
expectedBody string
claim map[string]string
cmdArgs string
}{
tests := []Test{
{
name: "No Authorization Header",
authHeader: "",
Expand Down Expand Up @@ -269,7 +274,7 @@ func TestRollout(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
request := createRequest(tt.authHeader)
request := createRequest(tt.authHeader, "GET", nil)
if tt.name == "No custom claim" {
os.Setenv("CUSTOM_CLAIMS", "")
} else {
Expand All @@ -285,7 +290,6 @@ func TestRollout(t *testing.T) {
assert.Equal(t, tt.expectedBody, recorder.Body.String())
})
}

testFiles := []string{
testFile,
"/tmp/rollout-shlex-test",
Expand All @@ -308,6 +312,86 @@ func TestRollout(t *testing.T) {
}
}

func TestRolloutCmdArgs(t *testing.T) {
os.Setenv("ROLLOUT_CMD", "/bin/bash")
s := createMockJwksServer()
defer s.Close()

// get a valid token
exp := time.Now().Add(time.Hour * 1).Unix()
jwtToken, err := CreateSignedJWT(kid, aud, claim, exp, privateKey)
if err != nil {
t.Fatalf("Unable to create a JWT with our test key: %v", err)
}

payloads := map[string]string{
"docker-image": "rollout-docker-image-test",
"docker-tag": "rollout-docker-tag-test",
"git-branch": "rollout-git-branch-test",
"git-repo": "rollout-git-repo-test",
"rollout-arg1": "rollout-arg1-test",
"rollout-arg2": "rollout-arg2-test",
"rollout-arg3": "rollout-arg3-test",
}
for k, v := range payloads {
var e string
switch k {
case "docker-image":
e = "DOCKER_IMAGE"
case "docker-tag":
e = "DOCKER_TAG"
case "git-branch":
e = "GIT_BRANCH"
case "git-repo":
e = "GIT_REPO"
case "rollout-arg1":
e = "ROLLOUT_ARG1"
case "rollout-arg2":
e = "ROLLOUT_ARG2"
case "rollout-arg3":
e = "ROLLOUT_ARG3"
}
tt := Test{
name: fmt.Sprintf("%s custom arg passes to rollout.sh", k),
authHeader: "Bearer " + jwtToken,
expectedStatus: http.StatusOK,
cmdArgs: fmt.Sprintf(`-c "touch /tmp/$%s"`, e),
method: "POST",
payload: fmt.Sprintf(`{"%s": "%s"}`, k, v),
expectedBody: "Rollout complete\n",
}
t.Run(tt.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
method := "POST"
body := strings.NewReader(tt.payload)
request := createRequest(tt.authHeader, method, body)
os.Setenv("ROLLOUT_ARGS", tt.cmdArgs)

Rollout(recorder, request)

assert.Equal(t, tt.expectedStatus, recorder.Code)
assert.Equal(t, tt.expectedBody, recorder.Body.String())
})
}

for _, v := range payloads {
f := "/tmp/" + v
// make sure the rollout command actually ran the command
// which creates the file
_, err = os.Stat(f)
if err != nil && os.IsNotExist(err) {
t.Errorf("The successful test did not create the expected file %s", f)
}

// cleanup
err = RemoveFileIfExists(f)
if err != nil {
slog.Error("Unable to cleanup test file", "file", f, "err", err)
os.Exit(1)
}
}
}

func RemoveFileIfExists(filePath string) error {
_, err := os.Stat(filePath)
if err == nil {
Expand Down

0 comments on commit afa9bbb

Please sign in to comment.