Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for ECDSA in Service Providers #586

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions samlsp/new.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,22 @@
package samlsp

import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
"net/http"
"net/url"

dsig "github.com/russellhaering/goxmldsig"

"github.com/crewjam/saml"
"github.com/golang-jwt/jwt/v4"
)

// Options represents the parameters for creating a new middleware
type Options struct {
EntityID string
URL url.URL
Key *rsa.PrivateKey
Key crypto.Signer
Certificate *x509.Certificate
Intermediates []*x509.Certificate
HTTPClient *http.Client
Expand All @@ -33,11 +34,23 @@ type Options struct {
LogoutBindings []string
}

func getDefaultSigningMethod(signer crypto.Signer) jwt.SigningMethod {
if signer != nil {
switch signer.Public().(type) {
case *ecdsa.PublicKey:
return jwt.SigningMethodES256
case *rsa.PublicKey:
return jwt.SigningMethodRS256
}
}
return jwt.SigningMethodRS256
}

// DefaultSessionCodec returns the default SessionCodec for the provided options,
// a JWTSessionCodec configured to issue signed tokens.
func DefaultSessionCodec(opts Options) JWTSessionCodec {
return JWTSessionCodec{
SigningMethod: defaultJWTSigningMethod,
SigningMethod: getDefaultSigningMethod(opts.Key),
Audience: opts.URL.String(),
Issuer: opts.URL.String(),
MaxAge: defaultSessionMaxAge,
Expand Down Expand Up @@ -67,7 +80,7 @@ func DefaultSessionProvider(opts Options) CookieSessionProvider {
// options, a JWTTrackedRequestCodec that uses a JWT to encode TrackedRequests.
func DefaultTrackedRequestCodec(opts Options) JWTTrackedRequestCodec {
return JWTTrackedRequestCodec{
SigningMethod: defaultJWTSigningMethod,
SigningMethod: getDefaultSigningMethod(opts.Key),
Audience: opts.URL.String(),
Issuer: opts.URL.String(),
MaxAge: saml.MaxIssueDelay,
Expand Down Expand Up @@ -99,9 +112,9 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider {
if opts.ForceAuthn {
forceAuthn = &opts.ForceAuthn
}
signatureMethod := dsig.RSASHA1SignatureMethod
if !opts.SignRequest {
signatureMethod = ""
var signatureMethod string
if opts.SignRequest {
signatureMethod = "auto"
}

if opts.DefaultRedirectURI == "" {
Expand Down
6 changes: 2 additions & 4 deletions samlsp/request_tracker_jwt.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package samlsp

import (
"crypto/rsa"
"crypto"
"fmt"
"time"

Expand All @@ -10,15 +10,13 @@
"github.com/crewjam/saml"
)

var defaultJWTSigningMethod = jwt.SigningMethodRS256

// JWTTrackedRequestCodec encodes TrackedRequests as signed JWTs
type JWTTrackedRequestCodec struct {
SigningMethod jwt.SigningMethod

Check failure on line 15 in samlsp/request_tracker_jwt.go

View workflow job for this annotation

GitHub Actions / golangci

undefined: jwt (typecheck)
Audience string
Issuer string
MaxAge time.Duration
Key *rsa.PrivateKey
Key crypto.Signer
}

var _ TrackedRequestCodec = JWTTrackedRequestCodec{}
Expand Down Expand Up @@ -61,7 +59,7 @@
if err != nil {
return nil, err
}
if !claims.VerifyAudience(s.Audience, true) {

Check failure on line 62 in samlsp/request_tracker_jwt.go

View workflow job for this annotation

GitHub Actions / golangci

claims.VerifyAudience undefined (type JWTTrackedRequestClaims has no field or method VerifyAudience) (typecheck)
return nil, fmt.Errorf("expected audience %q, got %q", s.Audience, claims.Audience)
}
if !claims.VerifyIssuer(s.Issuer, true) {
Expand Down
4 changes: 2 additions & 2 deletions samlsp/session_jwt.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package samlsp

import (
"crypto/rsa"
"crypto"
"errors"
"fmt"
"time"
Expand All @@ -23,7 +23,7 @@ type JWTSessionCodec struct {
Audience string
Issuer string
MaxAge time.Duration
Key *rsa.PrivateKey
Key crypto.Signer
}

var _ SessionCodec = JWTSessionCodec{}
Expand Down
77 changes: 49 additions & 28 deletions service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"bytes"
"compress/flate"
"context"
"crypto/rsa"
"crypto"
"crypto/tls"
"crypto/x509"
"encoding/base64"
Expand Down Expand Up @@ -67,7 +67,7 @@ type ServiceProvider struct {
EntityID string

// Key is the RSA private key we use to sign requests.
Key *rsa.PrivateKey
Key crypto.Signer

// Certificate is the RSA public part of Key.
Certificate *x509.Certificate
Expand Down Expand Up @@ -117,7 +117,7 @@ type ServiceProvider struct {
// to verify signatures.
SignatureVerifier SignatureVerifier

// SignatureMethod, if non-empty, authentication requests will be signed
// SignatureMethod, if non-empty, authentication requests will be signed. "auto" will determine method based on certificate type.
SignatureMethod string

// LogoutBindings specify the bindings available for SLO endpoint. If empty,
Expand All @@ -141,6 +141,11 @@ const DefaultValidDuration = time.Hour * 24 * 2
// DefaultCacheDuration is how long we ask the IDP to cache the SP metadata.
const DefaultCacheDuration = time.Hour * 24 * 1

// SignRequests returns true if the service provider should sign requests.
func (sp *ServiceProvider) SignRequests() bool {
return len(sp.SignatureMethod) > 0
}

// Metadata returns the service provider metadata
func (sp *ServiceProvider) Metadata() *EntityDescriptor {
validDuration := DefaultValidDuration
Expand Down Expand Up @@ -245,6 +250,19 @@ func (sp *ServiceProvider) MakeRedirectAuthenticationRequest(relayState string)
return req.Redirect(relayState, sp)
}

// GetSignatureMethod returns the appropriate string to represent the
// signature method for the service provider.
func (sp *ServiceProvider) GetSignatureMethod() (string, error) {
if sp.SignatureMethod == "auto" {
signingContext, err := GetSigningContext(sp)
if err != nil {
return "auto", err
}
return signingContext.GetSignatureMethodIdentifier(), nil
}
return sp.SignatureMethod, nil
}

// Redirect returns a URL suitable for using the redirect binding with the request
func (r *AuthnRequest) Redirect(relayState string, sp *ServiceProvider) (*url.URL, error) {
w := &bytes.Buffer{}
Expand Down Expand Up @@ -274,13 +292,16 @@ func (r *AuthnRequest) Redirect(relayState string, sp *ServiceProvider) (*url.UR
if relayState != "" {
query += "&RelayState=" + relayState
}
if len(sp.SignatureMethod) > 0 {
query += "&SigAlg=" + url.QueryEscape(sp.SignatureMethod)
if sp.SignRequests() {
signingContext, err := GetSigningContext(sp)

if err != nil {
return nil, err
}
sigMethod, err := sp.GetSignatureMethod()
if err != nil {
return nil, err
}
query += "&SigAlg=" + url.QueryEscape(sigMethod)

sig, err := signingContext.SignString(query)
if err != nil {
Expand Down Expand Up @@ -391,7 +412,7 @@ func (sp *ServiceProvider) MakeArtifactResolveRequest(artifactID string) (*Artif
Artifact: artifactID,
}

if len(sp.SignatureMethod) > 0 {
if sp.SignRequests() {
if err := sp.SignArtifactResolve(&req); err != nil {
return nil, err
}
Expand Down Expand Up @@ -428,7 +449,7 @@ func (sp *ServiceProvider) MakeAuthenticationRequest(idpURL string, binding stri
RequestedAuthnContext: sp.RequestedAuthnContext,
}
// We don't need to sign the XML document if the IDP uses HTTP-Redirect binding
if len(sp.SignatureMethod) > 0 && binding == HTTPPostBinding {
if sp.SignRequests() && binding == HTTPPostBinding {
if err := sp.SignAuthnRequest(&req); err != nil {
return nil, err
}
Expand All @@ -449,13 +470,13 @@ func GetSigningContext(sp *ServiceProvider) (*dsig.SigningContext, error) {
// }
keyStore := dsig.TLSCertKeyStore(keyPair)

if sp.SignatureMethod != dsig.RSASHA1SignatureMethod &&
sp.SignatureMethod != dsig.RSASHA256SignatureMethod &&
sp.SignatureMethod != dsig.RSASHA512SignatureMethod {
return nil, fmt.Errorf("invalid signing method %s", sp.SignatureMethod)
signer, _ := sp.Key.(crypto.Signer)
chain, _ := keyStore.GetChain()
signingContext, err := dsig.NewSigningContext(signer, chain)
if err != nil {
return nil, err
}
signatureMethod := sp.SignatureMethod
signingContext := dsig.NewDefaultSigningContext(keyStore)
signatureMethod := signingContext.GetSignatureMethodIdentifier()
signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList)
if err := signingContext.SetSignatureMethod(signatureMethod); err != nil {
return nil, err
Expand Down Expand Up @@ -1170,13 +1191,13 @@ func (sp *ServiceProvider) SignLogoutRequest(req *LogoutRequest) error {
// }
keyStore := dsig.TLSCertKeyStore(keyPair)

if sp.SignatureMethod != dsig.RSASHA1SignatureMethod &&
sp.SignatureMethod != dsig.RSASHA256SignatureMethod &&
sp.SignatureMethod != dsig.RSASHA512SignatureMethod {
return fmt.Errorf("invalid signing method %s", sp.SignatureMethod)
signer, _ := sp.Key.(crypto.Signer)
chain, _ := keyStore.GetChain()
signingContext, err := dsig.NewSigningContext(signer, chain)
if err != nil {
return err
}
signatureMethod := sp.SignatureMethod
signingContext := dsig.NewDefaultSigningContext(keyStore)
signatureMethod := signingContext.GetSignatureMethodIdentifier()
signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList)
if err := signingContext.SetSignatureMethod(signatureMethod); err != nil {
return err
Expand Down Expand Up @@ -1213,7 +1234,7 @@ func (sp *ServiceProvider) MakeLogoutRequest(idpURL, nameID string) (*LogoutRequ
SPNameQualifier: sp.Metadata().EntityID,
},
}
if len(sp.SignatureMethod) > 0 {
if sp.SignRequests() {
if err := sp.SignLogoutRequest(&req); err != nil {
return nil, err
}
Expand Down Expand Up @@ -1327,7 +1348,7 @@ func (sp *ServiceProvider) MakeLogoutResponse(idpURL, logoutRequestID string) (*
},
}

if len(sp.SignatureMethod) > 0 {
if sp.SignRequests() {
if err := sp.SignLogoutResponse(&response); err != nil {
return nil, err
}
Expand Down Expand Up @@ -1435,13 +1456,13 @@ func (sp *ServiceProvider) SignLogoutResponse(resp *LogoutResponse) error {
// }
keyStore := dsig.TLSCertKeyStore(keyPair)

if sp.SignatureMethod != dsig.RSASHA1SignatureMethod &&
sp.SignatureMethod != dsig.RSASHA256SignatureMethod &&
sp.SignatureMethod != dsig.RSASHA512SignatureMethod {
return fmt.Errorf("invalid signing method %s", sp.SignatureMethod)
signer, _ := sp.Key.(crypto.Signer)
chain, _ := keyStore.GetChain()
signingContext, err := dsig.NewSigningContext(signer, chain)
if err != nil {
return err
}
signatureMethod := sp.SignatureMethod
signingContext := dsig.NewDefaultSigningContext(keyStore)
signatureMethod := signingContext.GetSignatureMethodIdentifier()
signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList)
if err := signingContext.SetSignatureMethod(signatureMethod); err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions service_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ func TestSPFailToProduceSignedRequestWithBogusSignatureMethod(t *testing.T) {
assert.Check(t, err)

_, err = s.MakeRedirectAuthenticationRequest("relayState")
assert.Check(t, is.ErrorContains(err, "invalid signing method bogus"))
assert.Check(t, is.ErrorContains(err, "unknown SignatureMethod: bogus"))
}

func TestSPCanProducePostLogoutRequest(t *testing.T) {
Expand Down Expand Up @@ -1665,7 +1665,7 @@ func TestMakeSignedArtifactResolveRequestWithBogusSignatureMethod(t *testing.T)
}

_, err := sp.MakeArtifactResolveRequest("artifactId")
assert.Check(t, is.ErrorContains(err, "invalid signing method bogus"))
assert.Check(t, is.ErrorContains(err, "unknown SignatureMethod: bogus"))

}

Expand Down
Loading