Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

467 feature request add dsts support #482

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
89 changes: 75 additions & 14 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@ import (
"testing"
"time"

"github.com/golang-jwt/jwt/v5"
"github.com/kylelemons/godebug/pretty"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/golang-jwt/jwt/v5"
"github.com/kylelemons/godebug/pretty"
)

// errorClient is an HTTP client for tests that should fail when confidential.Client sends a request
Expand Down Expand Up @@ -65,18 +66,18 @@ func TestCertFromPEM(t *testing.T) {

const (
authorityFmt = "https://%s/%s"
fakeAuthority = "https://fake_authority/fake"
fakeAuthority = "https://fake_authority/fake_tenant"
fakeClientID = "fake_client_id"
fakeSecret = "fake_secret"
fakeTokenEndpoint = "https://fake_authority/fake/token"
fakeTokenEndpoint = "https://fake_authority/fake_tenant/token"
localhost = "http://localhost"
refresh = "fake_refresh"
token = "fake_token"
)

var tokenScope = []string{"the_scope"}

func fakeClient(tk accesstokens.TokenResponse, credential Credential, options ...Option) (Client, error) {
func fakeClient(tk accesstokens.TokenResponse, credential Credential, fakeAuthority string, options ...Option) (Client, error) {
client, err := New(fakeAuthority, fakeClientID, credential, options...)
if err != nil {
return Client{}, err
Expand All @@ -86,7 +87,7 @@ func fakeClient(tk accesstokens.TokenResponse, credential Credential, options ..
}
client.base.Token.Authority = &fake.Authority{
InstanceResp: authority.InstanceDiscoveryResponse{
TenantDiscoveryEndpoint: "https://fake_authority/fake/discovery/endpoint",
TenantDiscoveryEndpoint: fakeAuthority + "/discovery/endpoint",
Metadata: []authority.InstanceDiscoveryMetadata{
{
PreferredNetwork: "fake_authority",
Expand All @@ -104,8 +105,12 @@ func fakeClient(tk accesstokens.TokenResponse, credential Credential, options ..
},
}
client.base.Token.Resolver = &fake.ResolveEndpoints{
Endpoints: authority.NewEndpoints("https://fake_authority/fake/auth",
fakeTokenEndpoint, "https://fake_authority/fake/jwt", "fake_authority"),
Endpoints: authority.NewEndpoints(
fakeAuthority+"/auth",
fakeAuthority+"/token",
fakeAuthority+"/jwt",
fakeAuthority,
),
}
client.base.Token.WSTrust = &fake.WSTrust{}
return client, nil
Expand Down Expand Up @@ -137,7 +142,7 @@ func TestAcquireTokenByCredential(t *testing.T) {
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
TokenType: "Bearer",
}, cred)
}, cred, fakeAuthority)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -231,7 +236,7 @@ func TestAcquireTokenByAssertionCallback(t *testing.T) {
return "", errors.New("expected error")
}
cred := NewCredFromAssertionCallback(getAssertion)
client, err := fakeClient(accesstokens.TokenResponse{}, cred)
client, err := fakeClient(accesstokens.TokenResponse{}, cred, fakeAuthority)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -275,7 +280,7 @@ func TestAcquireTokenByAuthCode(t *testing.T) {
Oid: "123-456",
TenantID: "fake",
Subject: "nothing",
Issuer: "https://fake_authority/fake",
Issuer: fakeAuthority,
Audience: "abc-123",
ExpirationTime: time.Now().Add(time.Hour).Unix(),
IssuedAt: time.Now().Add(-5 * time.Minute).Unix(),
Expand All @@ -290,7 +295,7 @@ func TestAcquireTokenByAuthCode(t *testing.T) {
},
}

client, err := fakeClient(tr, cred)
client, err := fakeClient(tr, cred, fakeAuthority)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -517,7 +522,7 @@ func TestNewCredFromCert(t *testing.T) {
AccessToken: token,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
}, cred, opts...)
}, cred, fakeAuthority, opts...)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -1309,7 +1314,7 @@ func TestWithAuthenticationScheme(t *testing.T) {
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
TokenType: "TokenType",
}, cred)
}, cred, fakeAuthority)
if err != nil {
t.Fatal(err)
}
Expand All @@ -1328,3 +1333,59 @@ func TestWithAuthenticationScheme(t *testing.T) {
t.Fatalf(`unexpected access token "%s"`, result.AccessToken)
}
}

func TestAcquireTokenByCredentialFromDSTS(t *testing.T) {
tests := map[string]struct {
cred string
}{
"secret": {cred: "fake_secret"},
"signed assertion": {cred: "fake_assertion"},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
cred, err := NewCredFromSecret(test.cred)
if err != nil {
t.Fatal(err)
}
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
TokenType: "Bearer",
}, cred, "https://fake_authority/dstsv2/"+authority.DSTSTenant)
if err != nil {
t.Fatal(err)
}

// expect first attempt to fail
_, err = client.AcquireTokenSilent(context.Background(), tokenScope)
if err == nil {
t.Errorf("unexpected nil error from AcquireTokenSilent: %s", err)
}

tk, err := client.AcquireTokenByCredential(context.Background(), tokenScope)
if err != nil {
t.Errorf("got err == %s, want err == nil", err)
}
if tk.AccessToken != token {
t.Errorf("unexpected access token %s", tk.AccessToken)
}

tk, err = client.AcquireTokenSilent(context.Background(), tokenScope)
if err != nil {
t.Errorf("got err == %s, want err == nil", err)
}
if tk.AccessToken != token {
t.Errorf("unexpected access token %s", tk.AccessToken)
}

// fail for another tenant

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tenant

Given that dSTS has a single tenant, there is no need for this test.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe there still is a value in checking that we will fail the execution in case anyone tries to do WithTenant on dSTS flow

tk, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("other"))
if err == nil {
t.Errorf("unexpected nil error from AcquireTokenSilent: %s", err)
}
})
}
}
3 changes: 2 additions & 1 deletion apps/internal/oauth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"io"
"time"

"github.com/google/uuid"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
Expand All @@ -18,7 +20,6 @@ import (
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs"
"github.com/google/uuid"
)

// ResolveEndpointer contains the methods for resolving authority endpoints.
Expand Down
81 changes: 47 additions & 34 deletions apps/internal/oauth/ops/authority/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (

const (
authorizationEndpoint = "https://%v/%v/oauth2/v2.0/authorize"
instanceDiscoveryEndpoint = "https://%v/common/discovery/instance"
aadInstanceDiscoveryEndpoint = "https://%v/common/discovery/instance"
tenantDiscoveryEndpointWithRegion = "https://%s.%s/%s/v2.0/.well-known/openid-configuration"
regionName = "REGION_NAME"
defaultAPIVersion = "2021-10-01"
Expand Down Expand Up @@ -137,8 +137,12 @@ const (
const (
AAD = "MSSTS"
ADFS = "ADFS"
DSTS = "DSTS"
)

// DSTSTenant is referenced throughout multiple files, let us use a const in case we ever need to change it.
const DSTSTenant = "7a433bfc-2514-4697-b467-e0933190487f"

// AuthenticationScheme is an extensibility mechanism designed to be used only by Azure Arc for proof of possession access tokens.
type AuthenticationScheme interface {
// Extra parameters that are added to the request to the /token endpoint.
Expand Down Expand Up @@ -236,23 +240,26 @@ func NewAuthParams(clientID string, authorityInfo Info) AuthParams {
// - the client is configured to authenticate only Microsoft accounts via the "consumers" endpoint
// - the resulting authority URL is invalid
func (p AuthParams) WithTenant(ID string) (AuthParams, error) {
switch ID {
case "", p.AuthorityInfo.Tenant:
// keep the default tenant because the caller didn't override it
if ID == "" || ID == p.AuthorityInfo.Tenant {
return p, nil
case "common", "consumers", "organizations":
if p.AuthorityInfo.AuthorityType == AAD {
}

var authority string
switch p.AuthorityInfo.AuthorityType {
case AAD:
if ID == "common" || ID == "consumers" || ID == "organizations" {
return p, fmt.Errorf(`tenant ID must be a specific tenant, not "%s"`, ID)
}
// else we'll return a better error below
}
if p.AuthorityInfo.AuthorityType != AAD {
return p, errors.New("the authority doesn't support tenants")
}
if p.AuthorityInfo.Tenant == "consumers" {
return p, errors.New(`client is configured to authenticate only personal Microsoft accounts, via the "consumers" endpoint`)
if p.AuthorityInfo.Tenant == "consumers" {
return p, errors.New(`client is configured to authenticate only personal Microsoft accounts, via the "consumers" endpoint`)
}
authority = "https://" + path.Join(p.AuthorityInfo.Host, ID)
case ADFS:
return p, errors.New("ADFS authority doesn't support tenants")
case DSTS:
return p, errors.New("dSTS authority doesn't support tenants")
}
authority := "https://" + path.Join(p.AuthorityInfo.Host, ID)

info, err := NewInfoFromAuthorityURI(authority, p.AuthorityInfo.ValidateAuthority, p.AuthorityInfo.InstanceDiscoveryDisabled)
if err == nil {
info.Region = p.AuthorityInfo.Region
Expand Down Expand Up @@ -344,44 +351,50 @@ type Info struct {
Host string
CanonicalAuthorityURI string
AuthorityType string
UserRealmURIPrefix string
ValidateAuthority bool
Tenant string
Region string
InstanceDiscoveryDisabled bool
}

func firstPathSegment(u *url.URL) (string, error) {
pathParts := strings.Split(u.EscapedPath(), "/")
if len(pathParts) >= 2 {
return pathParts[1], nil
}

return "", errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/<your tenant>"`)
}

// NewInfoFromAuthorityURI creates an AuthorityInfo instance from the authority URL provided.
func NewInfoFromAuthorityURI(authority string, validateAuthority bool, instanceDiscoveryDisabled bool) (Info, error) {
u, err := url.Parse(strings.ToLower(authority))
if err != nil || u.Scheme != "https" {
return Info{}, errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/<your tenant>"`)
if err != nil {
return Info{}, fmt.Errorf("couldn't parse authority url: %w", err)
}
if u.Scheme != "https" {
return Info{}, errors.New("authority url scheme must be https")
}

tenant, err := firstPathSegment(u)
if err != nil {
return Info{}, err
pathParts := strings.Split(u.EscapedPath(), "/")
if len(pathParts) < 2 {
return Info{}, errors.New(`authority must be an URL such as "https://login.microsoftonline.com/<your tenant>"`)
}
authorityType := AAD
if tenant == "adfs" {

var authorityType, tenant string
switch pathParts[1] {
case "adfs":
authorityType = ADFS
case "dstsv2":
if len(pathParts) != 3 {
return Info{}, fmt.Errorf("dSTS authority must be an https URL such as https://<authority>/dstsv2/%s", DSTSTenant)
}
if pathParts[2] != DSTSTenant {
return Info{}, fmt.Errorf("dSTS authority only accepts a single tenant %q", DSTSTenant)
}
authorityType = DSTS
tenant = DSTSTenant
default:
authorityType = AAD
tenant = pathParts[1]
}

// u.Host includes the port, if any, which is required for private cloud deployments
return Info{
Host: u.Host,
CanonicalAuthorityURI: fmt.Sprintf("https://%v/%v/", u.Host, tenant),
CanonicalAuthorityURI: authority,
AuthorityType: authorityType,
UserRealmURIPrefix: fmt.Sprintf("https://%v/common/userrealm/", u.Hostname()),
ValidateAuthority: validateAuthority,
Tenant: tenant,
InstanceDiscoveryDisabled: instanceDiscoveryDisabled,
Expand Down Expand Up @@ -525,7 +538,7 @@ func (c Client) AADInstanceDiscovery(ctx context.Context, authorityInfo Info) (I
discoveryHost = authorityInfo.Host
}

endpoint := fmt.Sprintf(instanceDiscoveryEndpoint, discoveryHost)
endpoint := fmt.Sprintf(aadInstanceDiscoveryEndpoint, discoveryHost)
err = c.Comm.JSONCall(ctx, endpoint, http.Header{}, qv, nil, &resp)
}
return resp, err
Expand Down
Loading