Skip to content

Commit

Permalink
feat(oauth): add support for dSTS authority type
Browse files Browse the repository at this point in the history
  • Loading branch information
handsomejack-42 committed Apr 15, 2024
1 parent 917001c commit 83ca36e
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 27 deletions.
74 changes: 72 additions & 2 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@ import (
"testing"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/kylelemons/godebug/pretty"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/golang-jwt/jwt/v5"
"github.com/kylelemons/godebug/pretty"
)

// errorClient is an HTTP client for tests that should fail when confidential.Client sends a request
Expand Down Expand Up @@ -1332,3 +1333,72 @@ func TestWithAuthenticationScheme(t *testing.T) {
t.Fatalf(`unexpected access token "%s"`, result.AccessToken)
}
}

func TestAcquireTokenByCredentialFromDSTS(t *testing.T) {
tests := map[string]struct {
cred string
}{
"secret": {cred: "fake_secret"},
"signed assertion": {cred: "fake_assertion"},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
cred, err := NewCredFromSecret(test.cred)
if err != nil {
t.Fatal(err)
}
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
TokenType: "Bearer",
}, cred, "https://fake_authority/dstsv2/fake_tenant")
if err != nil {
t.Fatal(err)
}

// expect first attempt to fail
_, err = client.AcquireTokenSilent(context.Background(), tokenScope)
if err == nil {
t.Errorf("unexpected nil error from AcquireTokenSilent: %s", err)
}

tk, err := client.AcquireTokenByCredential(context.Background(), tokenScope)
if err != nil {
t.Errorf("got err == %s, want err == nil", err)
}
if tk.AccessToken != token {
t.Errorf("unexpected access token %s", tk.AccessToken)
}

tk, err = client.AcquireTokenSilent(context.Background(), tokenScope)
if err != nil {
t.Errorf("got err == %s, want err == nil", err)
}
if tk.AccessToken != token {
t.Errorf("unexpected access token %s", tk.AccessToken)
}

// fail for another tenant
tk, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("other"))
if err == nil {
t.Errorf("unexpected nil error from AcquireTokenSilent: %s", err)
}
})
}

//// silent authentication should now succeed for the given tenant...
//if ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)); err != nil {
// t.Fatal(err)
//}
//if ar.AccessToken != accessToken {
// t.Fatal("cached access token should match the one returned by AcquireToken...")
//}
//// ...but fail for another tenant
//otherTenant := "not-" + tenant
//if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(otherTenant)); err == nil {
// t.Fatal("expected an error")
//}
}
7 changes: 7 additions & 0 deletions apps/internal/oauth/fake/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ func (f Authority) AADInstanceDiscovery(ctx context.Context, info authority.Info
return f.InstanceResp, nil
}

func (f Authority) DSTSInstanceDiscovery(_ context.Context, _ authority.Info) (authority.InstanceDiscoveryResponse, error) {
if f.Err {
return authority.InstanceDiscoveryResponse{}, errors.New("error")
}
return f.InstanceResp, nil
}

// WSTrust is a fake implementation of the oauth.fetchWSTrust interface.
type WSTrust struct {
// Set these to true to have their respective APIs return an error.
Expand Down
4 changes: 3 additions & 1 deletion apps/internal/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"io"
"time"

"github.com/google/uuid"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
Expand All @@ -18,7 +20,6 @@ import (
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs"
"github.com/google/uuid"
)

// ResolveEndpointer contains the methods for resolving authority endpoints.
Expand All @@ -44,6 +45,7 @@ type AccessTokens interface {
type FetchAuthority interface {
UserRealm(context.Context, authority.AuthParams) (authority.UserRealm, error)
AADInstanceDiscovery(context.Context, authority.Info) (authority.InstanceDiscoveryResponse, error)
DSTSInstanceDiscovery(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryResponse, error)
}

// FetchWSTrust contains the methods for interacting with WSTrust endpoints.
Expand Down
55 changes: 38 additions & 17 deletions apps/internal/oauth/ops/authority/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
const (
authorizationEndpoint = "https://%v/%v/oauth2/v2.0/authorize"
aadInstanceDiscoveryEndpoint = "https://%v/common/discovery/instance"
dstsInstanceDiscoveryEndpoint = "https://%v/dstsv2/common/discovery/instance"
tenantDiscoveryEndpointWithRegion = "https://%s.%s/%s/v2.0/.well-known/openid-configuration"
regionName = "REGION_NAME"
defaultAPIVersion = "2021-10-01"
Expand Down Expand Up @@ -137,6 +138,7 @@ const (
const (
AAD = "MSSTS"
ADFS = "ADFS"
DSTS = "DSTS"
)

// AuthenticationScheme is an extensibility mechanism designed to be used only by Azure Arc for proof of possession access tokens.
Expand Down Expand Up @@ -252,6 +254,8 @@ func (p AuthParams) WithTenant(ID string) (AuthParams, error) {
authority = "https://" + path.Join(p.AuthorityInfo.Host, ID)
case ADFS:
return p, errors.New("ADFS authority doesn't support tenants")
case DSTS:
authority = "https://" + path.Join(p.AuthorityInfo.Host, "dstsv2", ID)
}

info, err := NewInfoFromAuthorityURI(authority, p.AuthorityInfo.ValidateAuthority, p.AuthorityInfo.InstanceDiscoveryDisabled)
Expand Down Expand Up @@ -351,35 +355,40 @@ type Info struct {
InstanceDiscoveryDisabled bool
}

func firstPathSegment(u *url.URL) (string, error) {
pathParts := strings.Split(u.EscapedPath(), "/")
if len(pathParts) >= 2 {
return pathParts[1], nil
}

return "", errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/<your tenant>"`)
}

// NewInfoFromAuthorityURI creates an AuthorityInfo instance from the authority URL provided.
func NewInfoFromAuthorityURI(authority string, validateAuthority bool, instanceDiscoveryDisabled bool) (Info, error) {
u, err := url.Parse(strings.ToLower(authority))
if err != nil || u.Scheme != "https" {
return Info{}, errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/<your tenant>"`)
if err != nil {
return Info{}, fmt.Errorf("couldn't parse authority url: %w", err)
}
if u.Scheme != "https" {
return Info{}, errors.New("authority url scheme must be https")
}

tenant, err := firstPathSegment(u)
if err != nil {
return Info{}, err
pathParts := strings.Split(u.EscapedPath(), "/")
if len(pathParts) < 2 {
return Info{}, errors.New(`authority must be an URL such as "https://login.microsoftonline.com/<your tenant>"`)
}
authorityType := AAD
if tenant == "adfs" {

var authorityType, tenant string
switch pathParts[1] {
case "adfs":
authorityType = ADFS
case "dstsv2":
if len(pathParts) != 3 {
return Info{}, errors.New(`dSTS authority must be an https URL such as "https://<authority>/dstsv2/<your tenant>"`)
}
authorityType = DSTS
tenant = pathParts[2]
default:
authorityType = AAD
tenant = pathParts[1]
}

// u.Host includes the port, if any, which is required for private cloud deployments
return Info{
Host: u.Host,
CanonicalAuthorityURI: fmt.Sprintf("https://%v/%v/", u.Host, tenant),
CanonicalAuthorityURI: authority,
AuthorityType: authorityType,
ValidateAuthority: validateAuthority,
Tenant: tenant,
Expand Down Expand Up @@ -530,6 +539,18 @@ func (c Client) AADInstanceDiscovery(ctx context.Context, authorityInfo Info) (I
return resp, err
}

func (c Client) DSTSInstanceDiscovery(ctx context.Context, authorityInfo Info) (InstanceDiscoveryResponse, error) {
qv := url.Values{}
qv.Set("api-version", "1.1")
qv.Set("authorization_endpoint", fmt.Sprintf(authorizationEndpoint, authorityInfo.Host, authorityInfo.Tenant))

endpoint := fmt.Sprintf(dstsInstanceDiscoveryEndpoint, authorityInfo.Host)

resp := InstanceDiscoveryResponse{}
err := c.Comm.JSONCall(ctx, endpoint, http.Header{}, qv, nil, &resp)
return resp, err
}

func detectRegion(ctx context.Context) string {
region := os.Getenv(regionName)
if region != "" {
Expand Down
7 changes: 4 additions & 3 deletions apps/internal/oauth/ops/authority/authority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,10 @@ func TestAuthParamsWithTenant(t *testing.T) {
"do nothing if tenant override is empty": {authority: host + uuid1, tenant: "", expectedAuthority: host + uuid1},
"do nothing if tenant override equals tenant": {authority: host + uuid1, tenant: uuid1, expectedAuthority: host + uuid1},

"override common to tenant": {authority: host + "common", tenant: uuid1, expectedAuthority: host + uuid1},
"override organizations to tenant": {authority: host + "organizations", tenant: uuid1, expectedAuthority: host + uuid1},
"override tenant to tenant2": {authority: host + uuid1, tenant: uuid2, expectedAuthority: host + uuid2},
"override common to tenant": {authority: host + "common", tenant: uuid1, expectedAuthority: host + uuid1},
"override organizations to tenant": {authority: host + "organizations", tenant: uuid1, expectedAuthority: host + uuid1},
"override tenant to tenant2": {authority: host + uuid1, tenant: uuid2, expectedAuthority: host + uuid2},
"override tenant to tenant2 for dSTS": {authority: host + "dstsv2/" + uuid1, tenant: uuid2, expectedAuthority: host + "dstsv2/" + uuid2},

"tenant can't be common for AAD": {authority: host + uuid1, tenant: "common", expectError: true},
"tenant can't be consumers for AAD": {authority: host + uuid1, tenant: "consumers", expectError: true},
Expand Down
13 changes: 9 additions & 4 deletions apps/internal/oauth/resolvers.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (m *authorityEndpoint) ResolveEndpoints(ctx context.Context, authorityInfo
return endpoints, nil
}

endpoint, err := m.openIDConfigurationEndpoint(ctx, authorityInfo, userPrincipalName)
endpoint, err := m.openIDConfigurationEndpoint(ctx, authorityInfo)
if err != nil {
return authority.Endpoints{}, err
}
Expand Down Expand Up @@ -119,9 +119,15 @@ func (m *authorityEndpoint) addCachedEndpoints(authorityInfo authority.Info, use
m.cache[authorityInfo.CanonicalAuthorityURI] = updatedCacheEntry
}

func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (string, error) {
if authorityInfo.Tenant == "adfs" {
func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info) (string, error) {
if authorityInfo.AuthorityType == authority.ADFS {
return fmt.Sprintf("https://%s/adfs/.well-known/openid-configuration", authorityInfo.Host), nil
} else if authorityInfo.AuthorityType == authority.DSTS {
resp, err := m.rest.Authority().DSTSInstanceDiscovery(ctx, authorityInfo)
if err != nil {
return "", err
}
return resp.TenantDiscoveryEndpoint, err
} else if authorityInfo.ValidateAuthority && !authority.TrustedHost(authorityInfo.Host) {
resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo)
if err != nil {
Expand All @@ -134,7 +140,6 @@ func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, aut
return "", err
}
return resp.TenantDiscoveryEndpoint, nil

}

return authorityInfo.CanonicalAuthorityURI + "v2.0/.well-known/openid-configuration", nil
Expand Down

0 comments on commit 83ca36e

Please sign in to comment.