diff --git a/apps/internal/base/storage/storage.go b/apps/internal/base/storage/storage.go index 60d85474..799ae87e 100644 --- a/apps/internal/base/storage/storage.go +++ b/apps/internal/base/storage/storage.go @@ -177,6 +177,7 @@ func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse acces target := strings.Join(tokenResponse.GrantedScopes.Slice, scopeSeparator) cachedAt := time.Now() authnSchemeKeyID := authParameters.AuthnScheme.KeyID() + var account shared.Account if len(tokenResponse.RefreshToken) > 0 { @@ -200,6 +201,7 @@ func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse acces tokenResponse.TokenType, authnSchemeKeyID, ) + // Since we have a valid access token, cache it before moving on. if err := accessToken.Validate(); err == nil { if err := m.writeAccessToken(accessToken); err != nil { @@ -237,6 +239,7 @@ func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse acces } AppMetaData := NewAppMetaData(tokenResponse.FamilyID, clientID, environment) + if err := m.writeAppMetaData(AppMetaData); err != nil { return shared.Account{}, err } @@ -263,11 +266,11 @@ func (m *Manager) aadMetadataFromCache(ctx context.Context, authorityInfo author } func (m *Manager) aadMetadata(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryMetadata, error) { - m.aadCacheMu.Lock() - defer m.aadCacheMu.Unlock() if m.requests == nil { return authority.InstanceDiscoveryMetadata{}, fmt.Errorf("httpclient in oauth instance for fetching metadata is nil") } + m.aadCacheMu.Lock() + defer m.aadCacheMu.Unlock() discoveryResponse, err := m.requests.AADInstanceDiscovery(ctx, authorityInfo) if err != nil { return authority.InstanceDiscoveryMetadata{}, err diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index 0dfdd0b1..afa028db 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -238,7 +238,7 @@ func (client Client) AcquireToken(ctx context.Context, resource string, options // ignore cached access tokens when given claims if o.claims == "" { if cacheManager == nil { - return base.AuthResult{}, fmt.Errorf("cache instance is nil") + return base.AuthResult{}, errors.New("cache instance is nil") } storageTokenResponse, err := cacheManager.Read(ctx, fakeAuthParams) if err != nil { diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 73fa4036..e19a66c0 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -159,10 +159,7 @@ func Test_SystemAssigned_Returns_Token_Success(t *testing.T) { if err != nil { t.Fatal(err) } - if localUrl == nil { - t.Fatalf("url request is not on %s got %s", testCase.endpoint, localUrl) - } - if !strings.HasPrefix(localUrl.String(), testCase.endpoint) { + if localUrl == nil || !strings.HasPrefix(localUrl.String(), testCase.endpoint) { t.Fatalf("url request is not on %s got %s", testCase.endpoint, localUrl) } if testCase.miType.value() != systemAssignedManagedIdentity { diff --git a/apps/tests/devapps/managedidentity/managedidentity_sample.go b/apps/tests/devapps/managedidentity/managedidentity_sample.go index baa3caee..ce4ef38f 100644 --- a/apps/tests/devapps/managedidentity/managedidentity_sample.go +++ b/apps/tests/devapps/managedidentity/managedidentity_sample.go @@ -19,17 +19,6 @@ func runIMDSSystemAssigned() { log.Fatal(err) } fmt.Println("token expire at : ", result.ExpiresOn) - fmt.Println("token source : ", result.Metadata.TokenSource) - miSystemAssignedCache, err := mi.New(mi.SystemAssigned()) - if err != nil { - log.Fatal(err) - } - cachedResult, err := miSystemAssignedCache.AcquireToken(context.TODO(), "https://management.azure.com") - if err != nil { - log.Fatal(err) - } - fmt.Println("token expire at : ", cachedResult.ExpiresOn) - fmt.Println("token source : ", cachedResult.Metadata.TokenSource) }