Skip to content

Commit

Permalink
cache key and storage changes; default authnscheme
Browse files Browse the repository at this point in the history
  • Loading branch information
Manoj Ampalam committed Aug 2, 2023
1 parent b027b5c commit 4e69190
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 30 deletions.
4 changes: 3 additions & 1 deletion apps/confidential/confidential.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,9 @@ func (cca Client) AcquireTokenByCredential(ctx context.Context, scopes []string,
authParams.Scopes = scopes
authParams.AuthorizationType = authority.ATClientCredentials
authParams.Claims = o.claims
authParams.AuthnScheme = o.authnScheme
if o.authnScheme != nil {
authParams.AuthnScheme = o.authnScheme
}
token, err := cca.base.Token.Credential(ctx, authParams, cca.cred)
if err != nil {
return AuthResult{}, err
Expand Down
9 changes: 3 additions & 6 deletions apps/internal/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,6 @@ type AuthResult struct {
}

func (ar *AuthResult) ApplyAuthnScheme(params *authority.AuthParams) (AuthResult, error) {

if params.AuthnScheme == nil {
return *ar, nil
}

result := *ar
var err error
result.AccessToken, err = params.AuthnScheme.FormatAccessToken(ar.AccessToken)
Expand Down Expand Up @@ -302,7 +297,9 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen
authParams.AuthorizationType = silent.AuthorizationType
authParams.Claims = silent.Claims
authParams.UserAssertion = silent.UserAssertion
authParams.AuthnScheme = silent.AuthnScheme
if silent.AuthnScheme != nil {
authParams.AuthnScheme = silent.AuthnScheme
}

m := b.pmanager
if authParams.AuthorizationType != authority.ATOnBehalfOf {
Expand Down
8 changes: 5 additions & 3 deletions apps/internal/base/internal/storage/items.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,14 @@ type AccessToken struct {
ExtendedExpiresOn internalTime.Unix `json:"extended_expires_on,omitempty"`
CachedAt internalTime.Unix `json:"cached_at,omitempty"`
UserAssertionHash string `json:"user_assertion_hash,omitempty"`
AuthnSchemeKeyID string `json:"authentication_scheme_keyid,omitempty"`
TokenType string `json:"token_type,omitempty"`
AuthnSchemeKeyID string `json:"keyid,omitempty"`

AdditionalFields map[string]interface{}
}

// NewAccessToken is the constructor for AccessToken.
func NewAccessToken(homeID, env, realm, clientID string, cachedAt, expiresOn, extendedExpiresOn time.Time, scopes, token, authnSchemeKeyID string) AccessToken {
func NewAccessToken(homeID, env, realm, clientID string, cachedAt, expiresOn, extendedExpiresOn time.Time, scopes, token, tokenType, authnSchemeKeyID string) AccessToken {
return AccessToken{
HomeAccountID: homeID,
Environment: env,
Expand All @@ -93,14 +94,15 @@ func NewAccessToken(homeID, env, realm, clientID string, cachedAt, expiresOn, ex
CachedAt: internalTime.Unix{T: cachedAt.UTC()},
ExpiresOn: internalTime.Unix{T: expiresOn.UTC()},
ExtendedExpiresOn: internalTime.Unix{T: extendedExpiresOn.UTC()},
TokenType: tokenType,
AuthnSchemeKeyID: authnSchemeKeyID,
}
}

// Key outputs the key that can be used to uniquely look up this entry in a map.
func (a AccessToken) Key() string {
key := strings.Join(
[]string{a.HomeAccountID, a.Environment, a.CredentialType, a.ClientID, a.Realm, a.Scopes, a.AuthnSchemeKeyID},
[]string{a.HomeAccountID, a.Environment, a.CredentialType, a.ClientID, a.Realm, a.Scopes, a.TokenType},
shared.CacheKeySeparator,
)
return strings.ToLower(key)
Expand Down
13 changes: 6 additions & 7 deletions apps/internal/base/internal/storage/partitioned_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,8 @@ func (m *PartitionedManager) Read(ctx context.Context, authParameters authority.
realm := authParameters.AuthorityInfo.Tenant
clientID := authParameters.ClientID
scopes := authParameters.Scopes
authnSchemeKeyID := ""
if authParameters.AuthnScheme != nil {
authnSchemeKeyID = authParameters.AuthnScheme.KeyId()
}
authnSchemeKeyID := authParameters.AuthnScheme.KeyId()
tokenType := authParameters.AuthnScheme.AccessTokenType()

// fetch metadata if instanceDiscovery is enabled
aliases := []string{authParameters.AuthorityInfo.Host}
Expand All @@ -61,8 +59,8 @@ func (m *PartitionedManager) Read(ctx context.Context, authParameters authority.

// errors returned by read* methods indicate a cache miss and are therefore non-fatal. We continue populating
// TokenResponse fields so that e.g. lack of an ID token doesn't prevent the caller from receiving a refresh token.
accessToken, err := m.readAccessToken(aliases, realm, clientID, userAssertionHash, scopes, partitionKeyFromRequest+authnSchemeKeyID)
if err == nil {
accessToken, err := m.readAccessToken(aliases, realm, clientID, userAssertionHash, scopes, partitionKeyFromRequest+tokenType)
if err == nil && accessToken.AuthnSchemeKeyID == authnSchemeKeyID {
tr.AccessToken = accessToken
}
idToken, err := m.readIDToken(aliases, realm, clientID, userAssertionHash, getPartitionKeyIDTokenRead(accessToken))
Expand Down Expand Up @@ -123,6 +121,7 @@ func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenRes
tokenResponse.ExtExpiresOn.T,
target,
tokenResponse.AccessToken,
tokenResponse.TokenType,
authnSchemeKeyID,
)
if authParameters.AuthorizationType == authority.ATOnBehalfOf {
Expand All @@ -131,7 +130,7 @@ func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenRes

// Since we have a valid access token, cache it before moving on.
if err := accessToken.Validate(); err == nil {
if err := m.writeAccessToken(accessToken, getPartitionKeyAccessToken(accessToken)+authnSchemeKeyID); err != nil {
if err := m.writeAccessToken(accessToken, getPartitionKeyAccessToken(accessToken)+tokenResponse.TokenType); err != nil {
return account, err
}
} else {
Expand Down
21 changes: 11 additions & 10 deletions apps/internal/base/internal/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,8 @@ func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams)
realm := authParameters.AuthorityInfo.Tenant
clientID := authParameters.ClientID
scopes := authParameters.Scopes
authnSchemeKeyID := ""
if authParameters.AuthnScheme != nil {
authnSchemeKeyID = authParameters.AuthnScheme.KeyId()
}
authnSchemeKeyID := authParameters.AuthnScheme.KeyId()
tokenType := authParameters.AuthnScheme.AccessTokenType()

// fetch metadata if instanceDiscovery is enabled
aliases := []string{authParameters.AuthorityInfo.Host}
Expand All @@ -104,7 +102,7 @@ func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams)
aliases = metadata.Aliases
}

accessToken := m.readAccessToken(homeAccountID, aliases, realm, clientID, scopes, authnSchemeKeyID)
accessToken := m.readAccessToken(homeAccountID, aliases, realm, clientID, scopes, tokenType, authnSchemeKeyID)
tr.AccessToken = accessToken

if homeAccountID == "" {
Expand Down Expand Up @@ -170,6 +168,7 @@ func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse acces
tokenResponse.ExtExpiresOn.T,
target,
tokenResponse.AccessToken,
tokenResponse.TokenType,
authnSchemeKeyID,
)

Expand Down Expand Up @@ -258,17 +257,19 @@ func (m *Manager) aadMetadata(ctx context.Context, authorityInfo authority.Info)
return m.aadCache[authorityInfo.Host], nil
}

func (m *Manager) readAccessToken(homeID string, envAliases []string, realm, clientID string, scopes []string, authnSchemeKeyID string) AccessToken {
func (m *Manager) readAccessToken(homeID string, envAliases []string, realm, clientID string, scopes []string, tokenType, authnSchemeKeyID string) AccessToken {
m.contractMu.RLock()
defer m.contractMu.RUnlock()
// TODO: linear search (over a map no less) is slow for a large number (thousands) of tokens.
// this shows up as the dominating node in a profile. for real-world scenarios this likely isn't
// an issue, however if it does become a problem then we know where to look.
for _, at := range m.contract.AccessTokens {
if at.HomeAccountID == homeID && at.Realm == realm && at.ClientID == clientID && at.AuthnSchemeKeyID == authnSchemeKeyID {
if checkAlias(at.Environment, envAliases) {
if isMatchingScopes(scopes, at.Scopes) {
return at
if at.HomeAccountID == homeID && at.Realm == realm && at.ClientID == clientID {
if at.TokenType == tokenType && at.AuthnSchemeKeyID == authnSchemeKeyID {
if checkAlias(at.Environment, envAliases) {
if isMatchingScopes(scopes, at.Scopes) {
return at
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions apps/internal/oauth/ops/accesstokens/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ type TokenResponse struct {

AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`

FamilyID string `json:"foci"`
IDToken IDToken `json:"id_token"`
Expand Down
30 changes: 28 additions & 2 deletions apps/internal/oauth/ops/authority/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,36 @@ const (
ADFS = "ADFS"
)

// AuthenticationScheme interface
type AuthenticationScheme interface {
// Extra parameters that are added to the request to the /token endpoint.
GetTokenRequestParams() map[string]string
// Key ID of the public / private key pair used by the encryption algorithm, if any.
// Tokens obtained by authentication schemes that use this are bound to the KeyId, i.e.
// if a different kid is presented, the access token cannot be used.
KeyId() string
// Creates the access token that goes into an Authorization HTTP header.
FormatAccessToken(accessToken string) (string, error)
//Expected to match the token_type parameter returned by ESTS. Used to disambiguate
// between ATs of different types (e.g. Bearer and PoP) when loading from cache etc.
AccessTokenType() string
}

// default authn scheme realizing AuthenticationScheme for "Bearer" tokens
type BearerAuthenticationScheme struct{}

var bearerAuthnScheme BearerAuthenticationScheme

func (ba *BearerAuthenticationScheme) GetTokenRequestParams() map[string]string {
return nil
}
func (ba *BearerAuthenticationScheme) KeyId() string {
return ""
}
func (ba *BearerAuthenticationScheme) FormatAccessToken(accessToken string) (string, error) {
return accessToken, nil
}
func (ba *BearerAuthenticationScheme) AccessTokenType() string {
return "Bearer"
}

// AuthParams represents the parameters used for authorization for token acquisition.
Expand Down Expand Up @@ -187,7 +212,7 @@ type AuthParams struct {
LoginHint string
// DomainHint is a directive that can be used to accelerate the user to their federated IdP sign-in page
DomainHint string
// Authn Scheme
// Optional scheme passed by callers to custom format access token
AuthnScheme AuthenticationScheme
}

Expand All @@ -197,6 +222,7 @@ func NewAuthParams(clientID string, authorityInfo Info) AuthParams {
ClientID: clientID,
AuthorityInfo: authorityInfo,
CorrelationID: uuid.New().String(),
AuthnScheme: &bearerAuthnScheme,
}
}

Expand Down
4 changes: 3 additions & 1 deletion apps/public/public.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,9 @@ func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string,
authParams.DomainHint = o.domainHint
authParams.State = uuid.New().String()
authParams.Prompt = "select_account"
authParams.AuthnScheme = o.authnScheme
if o.authnScheme != nil {
authParams.AuthnScheme = o.authnScheme
}
res, err := pca.browserLogin(ctx, redirectURL, authParams, o.openURL)
if err != nil {
return AuthResult{}, err
Expand Down

0 comments on commit 4e69190

Please sign in to comment.