Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Improve authentication performance by caching on file system #1056

Merged
merged 1 commit into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 40 additions & 24 deletions src/lib/PnP.Framework/AuthenticationManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
using Microsoft.Identity.Client.AppConfig;
using Microsoft.Identity.Client.Broker;
using Microsoft.Identity.Client.Extensibility;
using Microsoft.Identity.Client.Extensions.Msal;
using Microsoft.SharePoint.Client;
using PnP.Core.Services;
using PnP.Framework.Http;
using PnP.Framework.Utilities;
using PnP.Framework.Utilities.Cache;
using PnP.Framework.Utilities.Context;
using System;
using System.Configuration;
Expand Down Expand Up @@ -347,14 +349,14 @@ public AuthenticationManager(SecureString accessToken)
/// <param name="managedIdentityUserAssignedIdentifier">The identifier of the User Assigned Managed Identity. Can be the clientId, objectId or resourceId. Mandatory when <paramref name="managedIdentityType"/> is not SystemAssigned. Should be omitted if it is SystemAssigned.</param>
public AuthenticationManager(string endpoint, string identityHeader, ManagedIdentityType managedIdentityType = ManagedIdentityType.SystemAssigned, string managedIdentityUserAssignedIdentifier = null)
{
if(managedIdentityType != ManagedIdentityType.SystemAssigned && string.IsNullOrWhiteSpace(managedIdentityUserAssignedIdentifier))
if (managedIdentityType != ManagedIdentityType.SystemAssigned && string.IsNullOrWhiteSpace(managedIdentityUserAssignedIdentifier))
{
throw new ArgumentException($"When {nameof(managedIdentityType)} is not SystemAssigned, {nameof(managedIdentityUserAssignedIdentifier)} must be provided", nameof(managedIdentityType));
}

authenticationType = managedIdentityType == ManagedIdentityType.SystemAssigned ? ClientContextType.SystemAssignedManagedIdentity : ClientContextType.UserAssignedManagedIdentity;
this.managedIdentityType = managedIdentityType;
this.managedIdentityUserAssignedIdentifier = managedIdentityUserAssignedIdentifier;
this.managedIdentityType = managedIdentityType;
this.managedIdentityUserAssignedIdentifier = managedIdentityUserAssignedIdentifier;

// Construct the URL to call to get the token based on the type of Managed Identity in use
switch (managedIdentityType)
Expand All @@ -379,7 +381,7 @@ public AuthenticationManager(string endpoint, string identityHeader, ManagedIden
Diagnostics.Log.Debug(Constants.LOGGING_SOURCE, "Using the system assigned managed identity");
mi = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned).WithHttpClientFactory(HttpClientFactory).Build();
break;
}
}

}

Expand Down Expand Up @@ -412,12 +414,14 @@ public AuthenticationManager(string clientId, string username, SecureString pass
if (!string.IsNullOrEmpty(redirectUrl))
{
builder = builder.WithRedirectUri(redirectUrl);
}
}
builder.WithLegacyCacheCompatibility(false);
this.username = username;
this.password = password;
publicClientApplication = builder.Build();

var cacheHelper = MsalCacheHelperUtility.CreateCacheHelper();
cacheHelper?.RegisterCache(publicClientApplication.UserTokenCache);
// register tokencache if callback provided
tokenCacheCallback?.Invoke(publicClientApplication.UserTokenCache);
authenticationType = ClientContextType.AzureADCredentials;
Expand All @@ -434,7 +438,7 @@ public AuthenticationManager(string clientId, string username, SecureString pass
/// <param name="azureEnvironment">The azure environment to use. Defaults to AzureEnvironment.Production</param>
/// <param name="tokenCacheCallback">If present, after setting up the base flow for authentication this callback will be called to register a custom tokencache. See https://aka.ms/msal-net-token-cache-serialization.</param>
/// <param name="useWAM">If true, uses WAM for authentication. Works only on Windows OS</param>
public AuthenticationManager(string clientId, Action<string, int> openBrowserCallback, string tenantId = null, string successMessageHtml = null, string failureMessageHtml = null, AzureEnvironment azureEnvironment = AzureEnvironment.Production, Action<ITokenCache> tokenCacheCallback = null, bool useWAM = false) : this(clientId, Utilities.OAuth.DefaultBrowserUi.FindFreeLocalhostRedirectUri(), tenantId, azureEnvironment, tokenCacheCallback , new Utilities.OAuth.DefaultBrowserUi(openBrowserCallback, successMessageHtml, failureMessageHtml), useWAM = false)
public AuthenticationManager(string clientId, Action<string, int> openBrowserCallback, string tenantId = null, string successMessageHtml = null, string failureMessageHtml = null, AzureEnvironment azureEnvironment = AzureEnvironment.Production, Action<ITokenCache> tokenCacheCallback = null, bool useWAM = false) : this(clientId, Utilities.OAuth.DefaultBrowserUi.FindFreeLocalhostRedirectUri(), tenantId, azureEnvironment, tokenCacheCallback, new Utilities.OAuth.DefaultBrowserUi(openBrowserCallback, successMessageHtml, failureMessageHtml), useWAM = false)
{
}

Expand All @@ -452,30 +456,39 @@ public AuthenticationManager(string clientId, string redirectUrl = null, string
{
this.azureEnvironment = azureEnvironment;

var builder = PublicClientApplicationBuilder.Create(clientId).WithHttpClientFactory(HttpClientFactory);
if (useWAM && Environment.OSVersion.Platform == PlatformID.Win32NT)
PublicClientApplicationBuilder builder = PublicClientApplicationBuilder.Create(clientId).WithHttpClientFactory(HttpClientFactory); ;
builder = GetBuilderWithAuthority(builder, azureEnvironment);
if (useWAM && SharedUtilities.IsWindowsPlatform())
{
BrokerOptions brokerOptions = new(BrokerOptions.OperatingSystems.Windows)
{
Title = "Login with M365 PnP"
Title = "Login with M365 PnP",
ListOperatingSystemAccounts = true,
};
builder = builder.WithBroker(brokerOptions).WithDefaultRedirectUri().WithParentActivityOrWindow(WindowHandleUtilities.GetConsoleOrTerminalWindow).WithHttpClientFactory(HttpClientFactory);
}

builder = GetBuilderWithAuthority(builder, azureEnvironment);
builder = builder.WithBroker(brokerOptions).WithDefaultRedirectUri().WithParentActivityOrWindow(WindowHandleUtilities.GetConsoleOrTerminalWindow);

if (!string.IsNullOrEmpty(redirectUrl))
{
builder = builder.WithRedirectUri(redirectUrl);
if (!string.IsNullOrEmpty(tenantId))
{
builder = builder.WithTenantId(tenantId);
}
}
if (!string.IsNullOrEmpty(tenantId))
else
{
builder = builder.WithTenantId(tenantId);
if (!string.IsNullOrEmpty(redirectUrl))
{
builder = builder.WithRedirectUri(redirectUrl);
}
if (!string.IsNullOrEmpty(tenantId))
{
builder = builder.WithTenantId(tenantId);
}
this.customWebUi = customWebUi;
}
builder.WithLegacyCacheCompatibility(false);
publicClientApplication = builder.Build();

this.customWebUi = customWebUi;
var cacheHelper = MsalCacheHelperUtility.CreateCacheHelper();
cacheHelper?.RegisterCache(publicClientApplication.UserTokenCache);

// register tokencache if callback provided
tokenCacheCallback?.Invoke(publicClientApplication.UserTokenCache);
Expand Down Expand Up @@ -524,6 +537,9 @@ public AuthenticationManager(string clientId, string tenantId, Func<DeviceCodeRe
builder.WithLegacyCacheCompatibility(false);
publicClientApplication = builder.Build();

var cacheHelper = MsalCacheHelperUtility.CreateCacheHelper();
cacheHelper?.RegisterCache(publicClientApplication.UserTokenCache);

// register tokencache if callback provided
tokenCacheCallback?.Invoke(publicClientApplication.UserTokenCache);

Expand Down Expand Up @@ -831,7 +847,7 @@ public async Task<string> GetAccessTokenAsync(string[] scopes, CancellationToken
{
AuthenticationResult authResult = null;


Diagnostics.Log.Debug("GetAccessTokenAsync", $"Authentication type: {authenticationType}");

switch (authenticationType)
Expand Down Expand Up @@ -954,7 +970,7 @@ public async Task<string> GetAccessTokenAsync(string[] scopes, CancellationToken
// If it is a Uri, we're going to assume the audience is the root part of the Uri, i.e. tenant.sharepoint.com
var audienceUri = new Uri(scopes.FirstOrDefault(s => Uri.IsWellFormedUriString(s, UriKind.Absolute)) ?? $"https://{GetGraphEndPoint()}");
return GetManagedIdentityToken($"{audienceUri.Scheme}://{audienceUri.Authority}");
}
}
case ClientContextType.PnPCoreSdk:
{
return await this.authenticationProvider.GetAccessTokenAsync(uri, scopes).ConfigureAwait(false);
Expand Down Expand Up @@ -1490,7 +1506,7 @@ public static string GetACSEndPoint(AzureEnvironment environment)
AzureEnvironment.Production => "accesscontrol.windows.net",
AzureEnvironment.Germany => "microsoftonline.de",
AzureEnvironment.China => "accesscontrol.chinacloudapi.cn",
AzureEnvironment.USGovernment => "accesscontrol.windows.net",
AzureEnvironment.USGovernment => "accesscontrol.windows.net",
AzureEnvironment.USGovernmentHigh => "microsoftonline.us",
AzureEnvironment.USGovernmentDoD => "microsoftonline.us",
AzureEnvironment.PPE => "windows-ppe.net",
Expand Down Expand Up @@ -1928,7 +1944,7 @@ public ConfidentialClientApplicationBuilder GetBuilderWithAuthority(Confidential
{
switch (azureEnvironment)
{
case AzureEnvironment.USGovernment:
case AzureEnvironment.USGovernment:
{
builder = builder.WithAuthority(AzureCloudInstance.AzurePublic, AadAuthorityAudience.AzureAdMyOrg);
break;
Expand Down
86 changes: 86 additions & 0 deletions src/lib/PnP.Framework/Utilities/Cache/MsalCacheHelperUtility.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
using Microsoft.Identity.Client.Extensions.Msal;
using System;
using System.Collections.Generic;
using System.IO;

namespace PnP.Framework.Utilities.Cache
{
public class MsalCacheHelperUtility
{

private static MsalCacheHelper MsalCacheHelper;
private static readonly object ObjectLock = new();

private static class Config
{
// Cache settings
public const string CacheFileName = "m365pnpmsal.cache";
public readonly static string CacheDir = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData), ".M365PnPAuthService");

public const string KeyChainServiceName = "M365.PnP.Framework";
public const string KeyChainAccountName = "M365PnPAuthCache";

public const string LinuxKeyRingSchema = "com.m365.pnp.auth.tokencache";
public const string LinuxKeyRingCollection = MsalCacheHelper.LinuxKeyRingDefaultCollection;
public const string LinuxKeyRingLabel = "MSAL token cache for M365 PnP Framework.";
public static readonly KeyValuePair<string, string> LinuxKeyRingAttr1 = new KeyValuePair<string, string>("Version", "1");
public static readonly KeyValuePair<string, string> LinuxKeyRingAttr2 = new KeyValuePair<string, string>("Product", "M365PnPAuth");
}

public static MsalCacheHelper CreateCacheHelper()
{
if (MsalCacheHelper == null)
{
lock (ObjectLock)
{
if (MsalCacheHelper == null)
{
StorageCreationProperties storageProperties;

try
{
storageProperties = new StorageCreationPropertiesBuilder(
Config.CacheFileName,
Config.CacheDir)
.WithLinuxKeyring(
Config.LinuxKeyRingSchema,
Config.LinuxKeyRingCollection,
Config.LinuxKeyRingLabel,
Config.LinuxKeyRingAttr1,
Config.LinuxKeyRingAttr2)
.WithMacKeyChain(
Config.KeyChainServiceName,
Config.KeyChainAccountName)
.Build();

var cacheHelper = MsalCacheHelper.CreateAsync(storageProperties).ConfigureAwait(false).GetAwaiter().GetResult();

cacheHelper.VerifyPersistence();
MsalCacheHelper = cacheHelper;

}
catch (MsalCachePersistenceException)
{
// do not use the same file name so as not to overwrite the encrypted version
storageProperties = new StorageCreationPropertiesBuilder(
Config.CacheFileName + ".plaintext",
Config.CacheDir)
.WithUnprotectedFile()
.Build();

var cacheHelper = MsalCacheHelper.CreateAsync(storageProperties).ConfigureAwait(false).GetAwaiter().GetResult();
cacheHelper.VerifyPersistence();

MsalCacheHelper = cacheHelper;
}
catch
{
MsalCacheHelper = null;
}
}
}
}
return MsalCacheHelper;
}
}
}
Loading