diff --git a/go.mod b/go.mod index c6cc60d5..147588da 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( cloud.google.com/go/bigquery v1.51.2 cloud.google.com/go/storage v1.30.1 github.com/aws/aws-sdk-go v1.44.285 - github.com/golang-jwt/jwt v3.2.2+incompatible // supports old JWT specifications + github.com/golang-jwt/jwt v3.2.1+incompatible // supports old JWT specifications github.com/golang-migrate/migrate/v4 v4.16.2 github.com/google/uuid v1.3.0 github.com/googleapis/gax-go/v2 v2.11.0 diff --git a/go.sum b/go.sum index 782157c1..593304f2 100644 --- a/go.sum +++ b/go.sum @@ -232,8 +232,8 @@ github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7a github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= -github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= -github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt v3.2.1+incompatible h1:73Z+4BJcrTC+KczS6WvTPvRGOp1WmfEP4Q1lOd9Z/+c= +github.com/golang-jwt/jwt v3.2.1+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-migrate/migrate/v4 v4.16.2 h1:8coYbMKUyInrFk1lfGfRovTLAW7PhWp8qQDT2iKfuoA= github.com/golang-migrate/migrate/v4 v4.16.2/go.mod h1:pfcJX4nPHaVdc5nmdCikFBWtm+UBpiZjRNNsyBbp0/o= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= diff --git a/src/server/user/claims_test.go b/src/server/user/claims_test.go new file mode 100644 index 00000000..e43e8c5a --- /dev/null +++ b/src/server/user/claims_test.go @@ -0,0 +1,90 @@ +package user + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "net/http" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateJWTFromAmazonOIDC(t *testing.T) { + // Generate a new private key + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + // Create a header + header := map[string]string{ + "alg": "ES256", + "kid": "testKid", + } + headerBytes, err := json.Marshal(header) + if err != nil { + t.Fatal(err) + } + + // Create a payload with an email claim + payload := map[string]string{ + "email": "test@example.com", + } + payloadBytes, err := json.Marshal(payload) + if err != nil { + t.Fatal(err) + } + + // Base64 encode the header and payload using old JWT spec, new spec uses base64.RawURLEncoding.EncodeToString + encodedHeader := base64.StdEncoding.EncodeToString(headerBytes) + encodedPayload := base64.StdEncoding.EncodeToString(payloadBytes) + + // Concatenate the encoded header and payload with a period separator + message := encodedHeader + "." + encodedPayload + + // Sign the message with the private key + hash := sha256.Sum256([]byte(message)) + r, s, err := ecdsa.Sign(rand.Reader, privateKey, hash[:]) + if err != nil { + t.Fatal(err) + } + + // Base64 encode the signature + signature := r.Bytes() + signature = append(signature, s.Bytes()...) + encodedSignature := base64.RawURLEncoding.EncodeToString(signature) + + // Concatenate the signature with the message with a period separator + tokenString := message + "." + encodedSignature + + claimsCheck := ClaimsCheck{ + audience: "test-audience", + requireIAP: false, + requireAmazonOIDC: true, + devClaimsEmail: "", + region: "us-east-1", + publicKeys: &sync.Map{}, + } + + // Store the public key + claimsCheck.publicKeys.Store(header["kid"], &privateKey.PublicKey) + + // Mock the http request + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("x-amzn-oidc-data", tokenString) + + ctx := context.Background() + claims := claimsCheck.validateJWTFromAmazonOIDC(ctx, req.Header.Get("x-amzn-oidc-data")) + + assert.NotNil(t, claims) + assert.Equal(t, "test@example.com", claims.Email) +}