Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code clarity update #492

Merged
merged 6 commits into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ func NewAuthHandlerFunc(ac Auth) (HandlerFunc, error) {
return h, err
}

// NoAuth returns a handler that does not perform any authentication.
func NoAuth() (HandlerFunc, error) {
return func(w http.ResponseWriter, r *http.Request) (context.Context, error) {
return r.Context(), nil
Expand Down Expand Up @@ -232,6 +233,7 @@ func NewAuth(ac Auth, log *zap.Logger, opt Options, hFn ...HandlerFunc) (
}, nil
}

// SimpleHandler is a simple auth handler that sets the user ID, provider and role
func SimpleHandler(ac Auth) (HandlerFunc, error) {
return func(_ http.ResponseWriter, r *http.Request) (context.Context, error) {
c := r.Context()
Expand All @@ -257,6 +259,7 @@ func SimpleHandler(ac Auth) (HandlerFunc, error) {

var Err401 = errors.New("401 unauthorized")

// HeaderHandler is a middleware that checks for a header value
func HeaderHandler(ac Auth) (HandlerFunc, error) {
hdr := ac.Header

Expand Down Expand Up @@ -287,14 +290,17 @@ func HeaderHandler(ac Auth) (HandlerFunc, error) {
}, nil
}

// IsAuth returns true if the context contains a user ID
func IsAuth(c context.Context) bool {
return c != nil && c.Value(core.UserIDKey) != nil
}

// UserID returns the user ID from the context
func UserID(c context.Context) interface{} {
return c.Value(core.UserIDKey)
}

// UserIDInt returns the user ID from the context as an int
func UserIDInt(c context.Context) int {
v, ok := UserID(c).(string)
if !ok {
Expand Down
7 changes: 6 additions & 1 deletion auth/internal/rails/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type Auth struct {
AuthSalt string
}

// NewAuth creates a new Auth instance
func NewAuth(version, secret string) (*Auth, error) {
ra := &Auth{
Secret: secret,
Expand Down Expand Up @@ -60,6 +61,7 @@ func NewAuth(version, secret string) (*Auth, error) {
return ra, nil
}

// ParseCookie parses the rails cookie and returns the user ID
func (ra Auth) ParseCookie(cookie string) (userID string, err error) {
var dcookie []byte

Expand Down Expand Up @@ -87,6 +89,7 @@ func (ra Auth) ParseCookie(cookie string) (userID string, err error) {
return
}

// ParseCookie parses the rails cookie and returns the user ID
func ParseCookie(cookie string) (string, error) {
if cookie[0] != '{' {
return getUserId4([]byte(cookie))
Expand All @@ -95,6 +98,7 @@ func ParseCookie(cookie string) (string, error) {
return getUserId([]byte(cookie))
}

// getUserId extracts the user ID from the session data
func getUserId(data []byte) (userID string, err error) {
var sessionData map[string]interface{}

Expand Down Expand Up @@ -135,10 +139,11 @@ func getUserId(data []byte) (userID string, err error) {
return
}

// getUserId4 extracts the user ID from the session data
func getUserId4(data []byte) (userID string, err error) {
sessionData, err := marshal.CreateMarshalledObject(data).GetAsMap()
if err != nil {
return
return "", err
}

wardenData, ok := sessionData["warden.user.user.key"]
Expand Down
2 changes: 2 additions & 0 deletions auth/internal/rails/cookie.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"golang.org/x/crypto/pbkdf2"
)

// parseCookie decrypts and parses a Rails session cookie
func parseCookie(cookie, secretKeyBase, salt, signSalt string) ([]byte, error) {
return session.DecryptSignedCookie(
cookie,
Expand All @@ -22,6 +23,7 @@ func parseCookie(cookie, secretKeyBase, salt, signSalt string) ([]byte, error) {

// {"session_id":"a71d6ffcd4ed5572ea2097f569eb95ef","warden.user.user.key":[[2],"$2a$11$q9Br7m4wJxQvF11hAHvTZO"],"_csrf_token":"HsYgrD2YBaWAabOYceN0hluNRnGuz49XiplmMPt43aY="}

// parseCookie52 decrypts and parses a Rails 5.2+ session cookie
func parseCookie52(cookie, secretKeyBase, authSalt string) ([]byte, error) {
ecookie, err := url.QueryUnescape(cookie)
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions auth/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ const (
authHeader = "Authorization"
)

// JwtHandler is a middleware that checks for a JWT token in the cookie or the
// authorization header. If the token is found, it is validated and the claims
func JwtHandler(ac Auth) (HandlerFunc, error) {
jwtProvider, err := provider.NewProvider(ac.JWT)
if err != nil {
Expand Down
5 changes: 5 additions & 0 deletions auth/provider/auth0.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type Auth0Provider struct {
issuer string
}

// NewAuth0Provider creates a new Auth0 JWT provider
func NewAuth0Provider(config JWTConfig) (*Auth0Provider, error) {
key, err := getKey(config)
if err != nil {
Expand All @@ -27,26 +28,30 @@ func NewAuth0Provider(config JWTConfig) (*Auth0Provider, error) {
}, nil
}

// KeyFunc returns a function that returns the key used to verify the JWT token
func (p *Auth0Provider) KeyFunc() jwt.Keyfunc {
return func(token *jwt.Token) (interface{}, error) {
return p.key, nil
}
}

// VerifyAudience checks if the audience claim is valid
func (p *Auth0Provider) VerifyAudience(claims jwt.MapClaims) bool {
if claims == nil {
return false
}
return claims.VerifyAudience(p.aud, p.aud != "")
}

// VerifyIssuer checks if the issuer claim is valid
func (p *Auth0Provider) VerifyIssuer(claims jwt.MapClaims) bool {
if claims == nil {
return false
}
return claims.VerifyIssuer(p.issuer, p.issuer != "")
}

// SetContextValues sets the user ID and provider in the context
func (p *Auth0Provider) SetContextValues(ctx context.Context, claims jwt.MapClaims) (context.Context, error) {
if claims == nil {
return ctx, errors.New("undefined claims")
Expand Down
6 changes: 6 additions & 0 deletions auth/provider/firebase.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type FirebaseProvider struct {
issuer string
}

// NewFirebaseProvider creates a new Firebase JWT provider
func NewFirebaseProvider(config JWTConfig) (*FirebaseProvider, error) {
issuer := config.Issuer
if issuer == "" {
Expand All @@ -46,24 +47,28 @@ func NewFirebaseProvider(config JWTConfig) (*FirebaseProvider, error) {
}, nil
}

// KeyFunc returns a function that returns the key used to verify the JWT token
func (p *FirebaseProvider) KeyFunc() jwt.Keyfunc {
return firebaseKeyFunction
}

// VerifyAudience checks if the audience claim is valid
func (p *FirebaseProvider) VerifyAudience(claims jwt.MapClaims) bool {
if claims == nil {
return false
}
return claims.VerifyAudience(p.aud, p.aud != "")
}

// VerifyIssuer checks if the issuer claim is valid
func (p *FirebaseProvider) VerifyIssuer(claims jwt.MapClaims) bool {
if claims == nil {
return false
}
return claims.VerifyIssuer(p.issuer, p.issuer != "")
}

// SetContextValues sets the user ID and provider in the context
func (p *FirebaseProvider) SetContextValues(ctx context.Context, claims jwt.MapClaims) (context.Context, error) {
if claims == nil {
return ctx, errors.New("undefined claims")
Expand All @@ -85,6 +90,7 @@ func (e *firebaseKeyError) Error() string {
return e.Message + " " + e.Err.Error()
}

// firebaseKeyFunction returns the public key used to verify the JWT token
func firebaseKeyFunction(token *jwt.Token) (interface{}, error) {
kid, ok := token.Header["kid"]

Expand Down
5 changes: 5 additions & 0 deletions auth/provider/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type GenericProvider struct {
issuer string
}

// NewGenericProvider creates a new generic JWT provider
func NewGenericProvider(config JWTConfig) (*GenericProvider, error) {
key, err := getKey(config)
if err != nil {
Expand All @@ -26,26 +27,30 @@ func NewGenericProvider(config JWTConfig) (*GenericProvider, error) {
}, nil
}

// KeyFunc returns a function that returns the key used to verify the JWT token
func (p *GenericProvider) KeyFunc() jwt.Keyfunc {
return func(token *jwt.Token) (interface{}, error) {
return p.key, nil
}
}

// VerifyAudience verifies the audience claim of the JWT token
func (p *GenericProvider) VerifyAudience(claims jwt.MapClaims) bool {
if claims == nil {
return false
}
return claims.VerifyAudience(p.aud, p.aud != "")
}

// VerifyIssuer verifies the issuer claim of the JWT token
func (p *GenericProvider) VerifyIssuer(claims jwt.MapClaims) bool {
if claims == nil {
return false
}
return claims.VerifyIssuer(p.issuer, p.issuer != "")
}

// SetContextValues sets the user ID and provider in the context
func (p *GenericProvider) SetContextValues(ctx context.Context, claims jwt.MapClaims) (context.Context, error) {
if claims == nil {
return ctx, errors.New("undefined claims")
Expand Down
7 changes: 7 additions & 0 deletions auth/provider/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type keychainCache struct {
semaphore int32
}

// newKeychainCache creates a new KeychainCache
func newKeychainCache(jwksURL string, refreshInterval, minRefreshInterval int) *keychainCache {
ar := jwk.NewAutoRefresh(context.Background())
if refreshInterval > 0 {
Expand All @@ -34,6 +35,7 @@ func newKeychainCache(jwksURL string, refreshInterval, minRefreshInterval int) *
}
}

// getKey returns the key from the cache
func (k *keychainCache) getKey(kid string) (interface{}, error) {
set, err := k.keyCache.Fetch(context.TODO(), k.jwksURL)
if err != nil {
Expand Down Expand Up @@ -89,6 +91,7 @@ type JWKSProvider struct {
cache *keychainCache
}

// NewJWKSProvider creates a new JWKSProvider
func NewJWKSProvider(config JWTConfig) (*JWKSProvider, error) {
if config.JWKSURL == "" {
return nil, errors.New("undefined JWKSURL")
Expand All @@ -100,6 +103,7 @@ func NewJWKSProvider(config JWTConfig) (*JWKSProvider, error) {
}, nil
}

// KeyFunc returns a function that returns the key used to verify the JWT token
func (p *JWKSProvider) KeyFunc() jwt.Keyfunc {
return func(token *jwt.Token) (interface{}, error) {
if token == nil {
Expand All @@ -123,20 +127,23 @@ func (p *JWKSProvider) KeyFunc() jwt.Keyfunc {
}
}

// VerifyAudience checks if the audience claim is valid
func (p *JWKSProvider) VerifyAudience(claims jwt.MapClaims) bool {
if claims == nil {
return false
}
return claims.VerifyAudience(p.aud, p.aud != "")
}

// VerifyIssuer checks if the issuer claim is valid
func (p *JWKSProvider) VerifyIssuer(claims jwt.MapClaims) bool {
if claims == nil {
return false
}
return claims.VerifyIssuer(p.issuer, p.issuer != "")
}

// SetContextValues sets the user ID and provider in the context
func (p *JWKSProvider) SetContextValues(ctx context.Context, claims jwt.MapClaims) (context.Context, error) {
if claims == nil {
return ctx, errors.New("undefined claims")
Expand Down
10 changes: 6 additions & 4 deletions auth/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type JWTProvider interface {
SetContextValues(context.Context, jwt.MapClaims) (context.Context, error)
}

// NewProvider creates a new JWT provider based on the config values
func NewProvider(config JWTConfig) (JWTProvider, error) {
switch config.Provider {
case "auth0":
Expand All @@ -64,20 +65,21 @@ func NewProvider(config JWTConfig) (JWTProvider, error) {
}
}

// getKey returns the key used to verify the JWT token
func getKey(config JWTConfig) (interface{}, error) {
var key interface{}
var err error

switch {
case config.PubKey != "":
pk := []byte(config.PubKey)
pubKey := []byte(config.PubKey)
switch config.PubKeyType {
case "ecdsa":
key, err = jwt.ParseECPublicKeyFromPEM(pk)
key, err = jwt.ParseECPublicKeyFromPEM(pubKey)
case "rsa":
key, err = jwt.ParseRSAPublicKeyFromPEM(pk)
key, err = jwt.ParseRSAPublicKeyFromPEM(pubKey)
default:
key, err = jwt.ParseECPublicKeyFromPEM(pk)
key, err = jwt.ParseECPublicKeyFromPEM(pubKey)
}
if err != nil {
return nil, err
Expand Down
5 changes: 5 additions & 0 deletions auth/rails.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/gomodule/redigo/redis"
)

// RailsHandler returns a handler that authenticates using a Rails session cookie
func RailsHandler(ac Auth) (HandlerFunc, error) {
ru := ac.Rails.URL

Expand All @@ -29,6 +30,7 @@ func RailsHandler(ac Auth) (HandlerFunc, error) {
return RailsCookieHandler(ac)
}

// RailsRedisHandler returns a handler that authenticates using a Rails session cookie
func RailsRedisHandler(ac Auth) (HandlerFunc, error) {
cookie := ac.Cookie

Expand Down Expand Up @@ -95,6 +97,7 @@ func RailsRedisHandler(ac Auth) (HandlerFunc, error) {
}, nil
}

// RailsMemcacheHandler returns a handler that authenticates using a Rails session cookie
func RailsMemcacheHandler(ac Auth) (HandlerFunc, error) {
cookie := ac.Cookie

Expand Down Expand Up @@ -138,6 +141,7 @@ func RailsMemcacheHandler(ac Auth) (HandlerFunc, error) {
}, nil
}

// RailsCookieHandler returns a handler that authenticates using a Rails session cookie
func RailsCookieHandler(ac Auth) (HandlerFunc, error) {
cookie := ac.Cookie
if len(cookie) == 0 {
Expand Down Expand Up @@ -168,6 +172,7 @@ func RailsCookieHandler(ac Auth) (HandlerFunc, error) {
}, nil
}

// railsAuth returns a new rails auth instance
func railsAuth(ac Auth) (*rails.Auth, error) {
secret := ac.Rails.SecretKeyBase
if len(secret) == 0 {
Expand Down
Loading
Loading