diff --git a/cmd/main.go b/cmd/main.go index 3583234..077ff2f 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -6,9 +6,9 @@ import ( "log/slog" "os" - "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" "github.com/spf13/cobra" awsspiffe "github.com/spiffe/aws-spiffe-workload-helper" + "github.com/spiffe/aws-spiffe-workload-helper/internal/vendoredaws" "github.com/spiffe/go-spiffe/v2/workloadapi" ) @@ -95,7 +95,7 @@ func newX509CredentialProcessCmd() (*cobra.Command, error) { if err != nil { return fmt.Errorf("getting signature algorithm: %w", err) } - credentials, err := aws_signing_helper.GenerateCredentials(&aws_signing_helper.CredentialsOpts{ + credentials, err := vendoredaws.GenerateCredentials(&vendoredaws.CredentialsOpts{ RoleArn: roleARN, ProfileArnStr: profileARN, Region: region, diff --git a/internal/vendoredaws/LICENSE b/internal/vendoredaws/LICENSE new file mode 100644 index 0000000..19dc35b --- /dev/null +++ b/internal/vendoredaws/LICENSE @@ -0,0 +1,175 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. \ No newline at end of file diff --git a/internal/vendoredaws/NOTICE b/internal/vendoredaws/NOTICE new file mode 100644 index 0000000..f48b352 --- /dev/null +++ b/internal/vendoredaws/NOTICE @@ -0,0 +1 @@ +Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. \ No newline at end of file diff --git a/internal/vendoredaws/README.md b/internal/vendoredaws/README.md new file mode 100644 index 0000000..f4cd11d --- /dev/null +++ b/internal/vendoredaws/README.md @@ -0,0 +1,12 @@ +The code within this package is a partial vendoring of +https://github.com/aws/rolesanywhere-credential-helper/tree/main/aws_signing_helper + +The original source is licensed under Apache 2.0, this license can be found in +`LICENSE`. + +This code was vendored to break the dependency of `aws_signing_package` on +https://github.com/miekg/pkcs11, which requires CGO to build. + +An issue is open with the upstream repository to break apart the packages to +avoid this dependency, at which point this vendoring will be obselete: +https://github.com/aws/rolesanywhere-credential-helper/issues/86 \ No newline at end of file diff --git a/internal/vendoredaws/credentials.go b/internal/vendoredaws/credentials.go new file mode 100644 index 0000000..5fa4f59 --- /dev/null +++ b/internal/vendoredaws/credentials.go @@ -0,0 +1,133 @@ +package vendoredaws + +import ( + "crypto/tls" + "encoding/base64" + "errors" + "log" + "net/http" + "runtime" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" +) + +type CredentialsOpts struct { + PrivateKeyId string + CertificateId string + CertificateBundleId string + CertIdentifier CertIdentifier + RoleArn string + ProfileArnStr string + TrustAnchorArnStr string + SessionDuration int + Region string + Endpoint string + NoVerifySSL bool + WithProxy bool + Debug bool + Version string + LibPkcs11 string + ReusePin bool + ServerTTL int + RoleSessionName string +} + +// Function to create session and generate credentials +func GenerateCredentials(opts *CredentialsOpts, signer Signer, signatureAlgorithm string) (CredentialProcessOutput, error) { + // Assign values to region and endpoint if they haven't already been assigned + trustAnchorArn, err := arn.Parse(opts.TrustAnchorArnStr) + if err != nil { + return CredentialProcessOutput{}, err + } + profileArn, err := arn.Parse(opts.ProfileArnStr) + if err != nil { + return CredentialProcessOutput{}, err + } + + if trustAnchorArn.Region != profileArn.Region { + return CredentialProcessOutput{}, errors.New("trust anchor and profile regions don't match") + } + + if opts.Region == "" { + opts.Region = trustAnchorArn.Region + } + + mySession := session.Must(session.NewSession()) + + var logLevel aws.LogLevelType + if Debug { + logLevel = aws.LogDebug + } else { + logLevel = aws.LogOff + } + + var tr *http.Transport + if opts.WithProxy { + tr = &http.Transport{ + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: opts.NoVerifySSL}, + Proxy: http.ProxyFromEnvironment, + } + } else { + tr = &http.Transport{ + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: opts.NoVerifySSL}, + } + } + client := &http.Client{Transport: tr} + config := aws.NewConfig().WithRegion(opts.Region).WithHTTPClient(client).WithLogLevel(logLevel) + if opts.Endpoint != "" { + config.WithEndpoint(opts.Endpoint) + } + rolesAnywhereClient := rolesanywhere.New(mySession, config) + rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler") + rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: request.MakeAddToUserAgentHandler("CredHelper", opts.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH)}) + rolesAnywhereClient.Handlers.Sign.Clear() + certificate, err := signer.Certificate() + if err != nil { + return CredentialProcessOutput{}, errors.New("unable to find certificate") + } + certificateChain, err := signer.CertificateChain() + if err != nil { + // If the chain couldn't be found, don't include it in the request + if Debug { + log.Println(err) + } + } + rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: CreateRequestSignFunction(signer, signatureAlgorithm, certificate, certificateChain)}) + + certificateStr := base64.StdEncoding.EncodeToString(certificate.Raw) + durationSeconds := int64(opts.SessionDuration) + createSessionRequest := rolesanywhere.CreateSessionInput{ + Cert: &certificateStr, + ProfileArn: &opts.ProfileArnStr, + TrustAnchorArn: &opts.TrustAnchorArnStr, + DurationSeconds: &(durationSeconds), + InstanceProperties: nil, + RoleArn: &opts.RoleArn, + SessionName: nil, + } + if opts.RoleSessionName != "" { + createSessionRequest.RoleSessionName = &opts.RoleSessionName + } + output, err := rolesAnywhereClient.CreateSession(&createSessionRequest) + if err != nil { + return CredentialProcessOutput{}, err + } + + if len(output.CredentialSet) == 0 { + msg := "unable to obtain temporary security credentials from CreateSession" + return CredentialProcessOutput{}, errors.New(msg) + } + credentials := output.CredentialSet[0].Credentials + credentialProcessOutput := CredentialProcessOutput{ + Version: 1, + AccessKeyId: *credentials.AccessKeyId, + SecretAccessKey: *credentials.SecretAccessKey, + SessionToken: *credentials.SessionToken, + Expiration: *credentials.Expiration, + } + return credentialProcessOutput, nil +} diff --git a/internal/vendoredaws/signer.go b/internal/vendoredaws/signer.go new file mode 100644 index 0000000..d66e453 --- /dev/null +++ b/internal/vendoredaws/signer.go @@ -0,0 +1,672 @@ +package vendoredaws + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/asn1" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "errors" + "fmt" + "io" + "log" + "math/big" + "net/http" + "os" + "sort" + "strings" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "golang.org/x/crypto/pkcs12" +) + +type SignerParams struct { + OverriddenDate time.Time + RegionName string + ServiceName string + SigningAlgorithm string +} + +type CertIdentifier struct { + Subject string + Issuer string + SerialNumber *big.Int + SystemStoreName string // Only relevant in the case of Windows +} + +var ( + // ErrUnsupportedHash is returned by Signer.Sign() when the provided hash + // algorithm isn't supported. + ErrUnsupportedHash = errors.New("unsupported hash algorithm") + + // Predefined system store names. + // See: https://learn.microsoft.com/en-us/windows/win32/seccrypto/system-store-locations + SystemStoreNames = []string{ + "MY", + "Root", + "Trust", + "CA", + } +) + +// Interface that all signers will have to implement +// (as a result, they will also implement crypto.Signer) +type Signer interface { + Public() crypto.PublicKey + Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) + Certificate() (certificate *x509.Certificate, err error) + CertificateChain() (certificateChain []*x509.Certificate, err error) + Close() +} + +// Container for certificate data returned to the SDK as JSON. +type CertificateData struct { + // Type for the key contained in the certificate. + // Passed back to the `sign-string` command + KeyType string `json:"keyType"` + // Certificate, as base64-encoded DER; used in the `x-amz-x509` + // header in the API request. + CertificateData string `json:"certificateData"` + // Serial number of the certificate. Used in the credential + // field of the Authorization header + SerialNumber string `json:"serialNumber"` + // Supported signing algorithms based on the KeyType + Algorithms []string `json:"supportedAlgorithms"` +} + +// Container that adheres to the format of credential_process output as specified by AWS. +type CredentialProcessOutput struct { + // This field should be hard-coded to 1 for now. + Version int `json:"Version"` + // AWS Access Key ID + AccessKeyId string `json:"AccessKeyId"` + // AWS Secret Access Key + SecretAccessKey string `json:"SecretAccessKey"` + // AWS Session Token for temporary credentials + SessionToken string `json:"SessionToken"` + // ISO8601 timestamp for when the credentials expire + Expiration string `json:"Expiration"` +} + +type CertificateContainer struct { + // Certificate data + Cert *x509.Certificate + // Certificate URI (only populated in the case that the certificate is a PKCS#11 object) + Uri string +} + +// Define constants used in signing +const ( + aws4_x509_rsa_sha256 = "AWS4-X509-RSA-SHA256" + aws4_x509_ecdsa_sha256 = "AWS4-X509-ECDSA-SHA256" + timeFormat = "20060102T150405Z" + shortTimeFormat = "20060102" + x_amz_date = "X-Amz-Date" + x_amz_x509 = "X-Amz-X509" + x_amz_x509_chain = "X-Amz-X509-Chain" + x_amz_content_sha256 = "X-Amz-Content-Sha256" + authorization = "Authorization" + host = "Host" + emptyStringSHA256 = `e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855` +) + +// Headers that aren't included in calculating the signature +var ignoredHeaderKeys = map[string]bool{ + "Authorization": true, + "User-Agent": true, + "X-Amzn-Trace-Id": true, +} + +var Debug bool = false + +// Find whether the current certificate matches the CertIdentifier +func certMatches(certIdentifier CertIdentifier, cert x509.Certificate) bool { + if certIdentifier.Subject != "" && certIdentifier.Subject != cert.Subject.String() { + return false + } + if certIdentifier.Issuer != "" && certIdentifier.Issuer != cert.Issuer.String() { + return false + } + if certIdentifier.SerialNumber != nil && certIdentifier.SerialNumber.Cmp(cert.SerialNumber) != 0 { + return false + } + + return true +} + +// Because of *course* we have to do this for ourselves. +// +// Create the DER-encoded SEQUENCE containing R and S: +// +// Ecdsa-Sig-Value ::= SEQUENCE { +// r INTEGER, +// s INTEGER +// } +// +// This is defined in RFC3279 ยง2.2.3 as well as SEC.1. +// I can't find anything which mandates DER but I've seen +// OpenSSL refusing to verify it with indeterminate length. +func encodeEcdsaSigValue(signature []byte) (out []byte, err error) { + sigLen := len(signature) / 2 + + return asn1.Marshal(struct { + R *big.Int + S *big.Int + }{ + big.NewInt(0).SetBytes(signature[:sigLen]), + big.NewInt(0).SetBytes(signature[sigLen:])}) +} + +// Obtain the date-time, formatted as specified by SigV4 +func (signerParams *SignerParams) GetFormattedSigningDateTime() string { + return signerParams.OverriddenDate.UTC().Format(timeFormat) +} + +// Obtain the short date-time, formatted as specified by SigV4 +func (signerParams *SignerParams) GetFormattedShortSigningDateTime() string { + return signerParams.OverriddenDate.UTC().Format(shortTimeFormat) +} + +// Obtain the scope as part of the SigV4-X509 signature +func (signerParams *SignerParams) GetScope() string { + var scopeStringBuilder strings.Builder + scopeStringBuilder.WriteString(signerParams.GetFormattedShortSigningDateTime()) + scopeStringBuilder.WriteString("/") + scopeStringBuilder.WriteString(signerParams.RegionName) + scopeStringBuilder.WriteString("/") + scopeStringBuilder.WriteString(signerParams.ServiceName) + scopeStringBuilder.WriteString("/") + scopeStringBuilder.WriteString("aws4_request") + return scopeStringBuilder.String() +} + +// Convert certificate to string, so that it can be present in the HTTP request header +func certificateToString(certificate *x509.Certificate) string { + return base64.StdEncoding.EncodeToString(certificate.Raw) +} + +// Convert certificate chain to string, so that it can be pressent in the HTTP request header +func certificateChainToString(certificateChain []*x509.Certificate) string { + var x509ChainString strings.Builder + for i, certificate := range certificateChain { + x509ChainString.WriteString(certificateToString(certificate)) + if i != len(certificateChain)-1 { + x509ChainString.WriteString(",") + } + } + return x509ChainString.String() +} + +func CreateRequestSignFunction(signer crypto.Signer, signingAlgorithm string, certificate *x509.Certificate, certificateChain []*x509.Certificate) func(*request.Request) { + return func(req *request.Request) { + region := req.ClientInfo.SigningRegion + if region == "" { + region = aws.StringValue(req.Config.Region) + } + + name := req.ClientInfo.SigningName + if name == "" { + name = req.ClientInfo.ServiceName + } + + signerParams := SignerParams{time.Now(), region, name, signingAlgorithm} + + // Set headers that are necessary for signing + req.HTTPRequest.Header.Set(host, req.HTTPRequest.URL.Host) + req.HTTPRequest.Header.Set(x_amz_date, signerParams.GetFormattedSigningDateTime()) + req.HTTPRequest.Header.Set(x_amz_x509, certificateToString(certificate)) + if certificateChain != nil { + req.HTTPRequest.Header.Set(x_amz_x509_chain, certificateChainToString(certificateChain)) + } + + contentSha256 := calculateContentHash(req.HTTPRequest, req.Body) + if req.HTTPRequest.Header.Get(x_amz_content_sha256) == "required" { + req.HTTPRequest.Header.Set(x_amz_content_sha256, contentSha256) + } + + canonicalRequest, signedHeadersString := createCanonicalRequest(req.HTTPRequest, req.Body, contentSha256) + + stringToSign := CreateStringToSign(canonicalRequest, signerParams) + signatureBytes, err := signer.Sign(rand.Reader, []byte(stringToSign), crypto.SHA256) + if err != nil { + log.Println(err.Error()) + os.Exit(1) + } + signature := hex.EncodeToString(signatureBytes) + + req.HTTPRequest.Header.Set(authorization, BuildAuthorizationHeader(req.HTTPRequest, req.Body, signedHeadersString, signature, certificate, signerParams)) + req.SignedHeaderVals = req.HTTPRequest.Header + } +} + +// Find the SHA256 hash of the provided request body as a io.ReadSeeker +func makeSha256Reader(reader io.ReadSeeker) []byte { + hash := sha256.New() + start, _ := reader.Seek(0, 1) + defer reader.Seek(start, 0) + + io.Copy(hash, reader) + return hash.Sum(nil) +} + +// Calculate the hash of the request body +func calculateContentHash(r *http.Request, body io.ReadSeeker) string { + hash := r.Header.Get(x_amz_content_sha256) + + if hash == "" { + if body == nil { + hash = emptyStringSHA256 + } else { + hash = hex.EncodeToString(makeSha256Reader(body)) + } + } + + return hash +} + +// Create the canonical query string. +func createCanonicalQueryString(r *http.Request, body io.ReadSeeker) string { + rawQuery := strings.Replace(r.URL.Query().Encode(), "+", "%20", -1) + return rawQuery +} + +// Create the canonical header string. +func createCanonicalHeaderString(r *http.Request) (string, string) { + var headers []string + signedHeaderVals := make(http.Header) + for k, v := range r.Header { + canonicalKey := http.CanonicalHeaderKey(k) + if ignoredHeaderKeys[canonicalKey] { + continue + } + + lowerCaseKey := strings.ToLower(k) + if _, ok := signedHeaderVals[lowerCaseKey]; ok { + // include additional values + signedHeaderVals[lowerCaseKey] = append(signedHeaderVals[lowerCaseKey], v...) + continue + } + + headers = append(headers, lowerCaseKey) + signedHeaderVals[lowerCaseKey] = v + } + sort.Strings(headers) + + headerValues := make([]string, len(headers)) + for i, k := range headers { + headerValues[i] = k + ":" + strings.Join(signedHeaderVals[k], ",") + } + stripExcessSpaces(headerValues) + return strings.Join(headerValues, "\n"), strings.Join(headers, ";") +} + +const doubleSpace = " " + +// stripExcessSpaces will rewrite the passed in slice's string values to not +// contain muliple side-by-side spaces. +func stripExcessSpaces(vals []string) { + var j, k, l, m, spaces int + for i, str := range vals { + // Trim trailing spaces + for j = len(str) - 1; j >= 0 && str[j] == ' '; j-- { + } + + // Trim leading spaces + for k = 0; k < j && str[k] == ' '; k++ { + } + str = str[k : j+1] + + // Strip multiple spaces. + j = strings.Index(str, doubleSpace) + if j < 0 { + vals[i] = str + continue + } + + buf := []byte(str) + for k, m, l = j, j, len(buf); k < l; k++ { + if buf[k] == ' ' { + if spaces == 0 { + // First space. + buf[m] = buf[k] + m++ + } + spaces++ + } else { + // End of multiple spaces. + spaces = 0 + buf[m] = buf[k] + m++ + } + } + + vals[i] = string(buf[:m]) + } +} + +// Create the canonical request. +func createCanonicalRequest(r *http.Request, body io.ReadSeeker, contentSha256 string) (string, string) { + var canonicalRequestStrBuilder strings.Builder + canonicalHeaderString, signedHeadersString := createCanonicalHeaderString(r) + canonicalRequestStrBuilder.WriteString("POST") + canonicalRequestStrBuilder.WriteString("\n") + canonicalRequestStrBuilder.WriteString("/sessions") + canonicalRequestStrBuilder.WriteString("\n") + canonicalRequestStrBuilder.WriteString(createCanonicalQueryString(r, body)) + canonicalRequestStrBuilder.WriteString("\n") + canonicalRequestStrBuilder.WriteString(canonicalHeaderString) + canonicalRequestStrBuilder.WriteString("\n\n") + canonicalRequestStrBuilder.WriteString(signedHeadersString) + canonicalRequestStrBuilder.WriteString("\n") + canonicalRequestStrBuilder.WriteString(contentSha256) + canonicalRequestString := canonicalRequestStrBuilder.String() + canonicalRequestStringHashBytes := sha256.Sum256([]byte(canonicalRequestString)) + return hex.EncodeToString(canonicalRequestStringHashBytes[:]), signedHeadersString +} + +// Create the string to sign. +func CreateStringToSign(canonicalRequest string, signerParams SignerParams) string { + var stringToSignStrBuilder strings.Builder + stringToSignStrBuilder.WriteString(signerParams.SigningAlgorithm) + stringToSignStrBuilder.WriteString("\n") + stringToSignStrBuilder.WriteString(signerParams.GetFormattedSigningDateTime()) + stringToSignStrBuilder.WriteString("\n") + stringToSignStrBuilder.WriteString(signerParams.GetScope()) + stringToSignStrBuilder.WriteString("\n") + stringToSignStrBuilder.WriteString(canonicalRequest) + stringToSign := stringToSignStrBuilder.String() + return stringToSign +} + +// Builds the complete authorization header +func BuildAuthorizationHeader(request *http.Request, body io.ReadSeeker, signedHeadersString string, signature string, certificate *x509.Certificate, signerParams SignerParams) string { + signingCredentials := certificate.SerialNumber.String() + "/" + signerParams.GetScope() + credential := "Credential=" + signingCredentials + signerHeaders := "SignedHeaders=" + signedHeadersString + signatureHeader := "Signature=" + signature + + var authHeaderStringBuilder strings.Builder + authHeaderStringBuilder.WriteString(signerParams.SigningAlgorithm) + authHeaderStringBuilder.WriteString(" ") + authHeaderStringBuilder.WriteString(credential) + authHeaderStringBuilder.WriteString(", ") + authHeaderStringBuilder.WriteString(signerHeaders) + authHeaderStringBuilder.WriteString(", ") + authHeaderStringBuilder.WriteString(signatureHeader) + authHeaderString := authHeaderStringBuilder.String() + return authHeaderString +} + +func encodeDer(der []byte) (string, error) { + var buf bytes.Buffer + encoder := base64.NewEncoder(base64.StdEncoding, &buf) + encoder.Write(der) + encoder.Close() + return buf.String(), nil +} + +func parseDERFromPEM(pemDataId string, blockType string) (*pem.Block, error) { + bytes, err := os.ReadFile(pemDataId) + if err != nil { + log.Println(err) + return nil, err + } + + var block *pem.Block + for len(bytes) > 0 { + block, bytes = pem.Decode(bytes) + if block == nil { + return nil, errors.New("unable to parse PEM data") + } + if block.Type == blockType { + return block, nil + } + } + return nil, errors.New("requested block type could not be found") +} + +// Reads certificate bundle data from a file, whose path is provided +func ReadCertificateBundleData(certificateBundleId string) ([]*x509.Certificate, error) { + bytes, err := os.ReadFile(certificateBundleId) + if err != nil { + log.Println(err) + return nil, err + } + + var derBytes []byte + var block *pem.Block + for len(bytes) > 0 { + block, bytes = pem.Decode(bytes) + if block == nil { + break + } + if block.Type != "CERTIFICATE" { + return nil, errors.New("invalid certificate chain") + } + blockBytes := block.Bytes + derBytes = append(derBytes, blockBytes...) + } + + return x509.ParseCertificates(derBytes) +} + +func readECPrivateKey(privateKeyId string) (ecdsa.PrivateKey, error) { + block, err := parseDERFromPEM(privateKeyId, "EC PRIVATE KEY") + if err != nil { + return ecdsa.PrivateKey{}, errors.New("could not parse PEM data") + } + + privateKey, err := x509.ParseECPrivateKey(block.Bytes) + if err != nil { + return ecdsa.PrivateKey{}, errors.New("could not parse private key") + } + + return *privateKey, nil +} + +func readRSAPrivateKey(privateKeyId string) (rsa.PrivateKey, error) { + block, err := parseDERFromPEM(privateKeyId, "RSA PRIVATE KEY") + if err != nil { + return rsa.PrivateKey{}, errors.New("could not parse PEM data") + } + + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return rsa.PrivateKey{}, errors.New("could not parse private key") + } + + return *privateKey, nil +} + +func readPKCS8PrivateKey(privateKeyId string) (crypto.PrivateKey, error) { + block, err := parseDERFromPEM(privateKeyId, "PRIVATE KEY") + if err != nil { + return nil, errors.New("could not parse PEM data") + } + + privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, errors.New("could not parse private key") + } + + rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey) + if ok { + return *rsaPrivateKey, nil + } + + ecPrivateKey, ok := privateKey.(*ecdsa.PrivateKey) + if ok { + return *ecPrivateKey, nil + } + + return nil, errors.New("could not parse PKCS#8 private key") +} + +// Reads and parses a PKCS#12 file (which should contain an end-entity +// certificate, (optional) certificate chain, and the key associated with the +// end-entity certificate). The end-entity certificate will be the first +// certificate in the returned chain. This method assumes that there is +// exactly one certificate that doesn't issue any others within the container +// and treats that as the end-entity certificate. Also, the order of the other +// certificates in the chain aren't guaranteed (it's also not guaranteed that +// those certificates form a chain with the end-entity certificat either). +func ReadPKCS12Data(certificateId string) (certChain []*x509.Certificate, privateKey crypto.PrivateKey, err error) { + var ( + bytes []byte + pemBlocks []*pem.Block + parsedCerts []*x509.Certificate + certMap map[string]*x509.Certificate + endEntityFoundIndex int + ) + + bytes, err = os.ReadFile(certificateId) + if err != nil { + return nil, nil, nil + } + + pemBlocks, err = pkcs12.ToPEM(bytes, "") + if err != nil { + return nil, "", err + } + + for _, block := range pemBlocks { + cert, err := x509.ParseCertificate(block.Bytes) + if err == nil { + parsedCerts = append(parsedCerts, cert) + continue + } + privateKeyTmp, err := ReadPrivateKeyDataFromPEMBlock(block) + if err == nil { + privateKey = privateKeyTmp + continue + } + // If neither a certificate nor a private key could be parsed from the + // Block, ignore it and continue. + if Debug { + log.Println("unable to parse PEM block in PKCS#12 file - skipping") + } + } + + certMap = make(map[string]*x509.Certificate) + for _, cert := range parsedCerts { + // pkix.Name.String() roughly following the RFC 2253 Distinguished Names + // syntax, so we assume that it's canonical. + issuer := cert.Issuer.String() + certMap[issuer] = cert + } + + endEntityFoundIndex = -1 + for i, cert := range parsedCerts { + subject := cert.Subject.String() + if _, ok := certMap[subject]; !ok { + certChain = append(certChain, cert) + endEntityFoundIndex = i + break + } + } + if endEntityFoundIndex == -1 { + return nil, "", errors.New("no end-entity certificate found in PKCS#12 file") + } + + for i, cert := range parsedCerts { + if i != endEntityFoundIndex { + certChain = append(certChain, cert) + } + } + + return certChain, privateKey, nil +} + +// Load the private key referenced by `privateKeyId`. +func ReadPrivateKeyData(privateKeyId string) (crypto.PrivateKey, error) { + if key, err := readPKCS8PrivateKey(privateKeyId); err == nil { + return key, nil + } + + if key, err := readECPrivateKey(privateKeyId); err == nil { + return key, nil + } + + if key, err := readRSAPrivateKey(privateKeyId); err == nil { + return key, nil + } + + return nil, errors.New("unable to parse private key") +} + +// Reads private key data from a *pem.Block. +func ReadPrivateKeyDataFromPEMBlock(block *pem.Block) (key crypto.PrivateKey, err error) { + key, err = x509.ParseECPrivateKey(block.Bytes) + if err == nil { + return key, nil + } + + key, err = x509.ParsePKCS1PrivateKey(block.Bytes) + if err == nil { + return key, nil + } + + return nil, errors.New("unable to parse private key") +} + +// ReadCertificateData loads the certificate referenced by `certificateId` and extracts +// details required by the SDK to construct the StringToSign. +func ReadCertificateData(certificateId string) (CertificateData, *x509.Certificate, error) { + block, err := parseDERFromPEM(certificateId, "CERTIFICATE") + if err != nil { + return CertificateData{}, nil, errors.New("could not parse PEM data") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + log.Println("could not parse certificate", err) + return CertificateData{}, nil, errors.New("could not parse certificate") + } + + //extract serial number + serialNumber := cert.SerialNumber.String() + + //encode certificate + encodedDer, _ := encodeDer(block.Bytes) + + //extract key type + var keyType string + switch cert.PublicKeyAlgorithm { + case x509.RSA: + keyType = "RSA" + case x509.ECDSA: + keyType = "EC" + default: + keyType = "" + } + + supportedAlgorithms := []string{ + fmt.Sprintf("%sSHA256", keyType), + fmt.Sprintf("%sSHA384", keyType), + fmt.Sprintf("%sSHA512", keyType), + } + + //return struct + return CertificateData{keyType, encodedDer, serialNumber, supportedAlgorithms}, cert, nil +} + +// GetCertChain reads a certificate bundle and returns a chain of all the certificates it contains +func GetCertChain(certificateBundleId string) ([]*x509.Certificate, error) { + certificateChainPointers, err := ReadCertificateBundleData(certificateBundleId) + var chain []*x509.Certificate + if err != nil { + return nil, err + } + for _, certificate := range certificateChainPointers { + chain = append(chain, certificate) + } + return chain, nil +}