From e62845017091cdd02b69c22d6a60d5322468b5b3 Mon Sep 17 00:00:00 2001 From: Brent Schmaltz Date: Thu, 31 Aug 2023 05:58:32 -0700 Subject: [PATCH] Reduce allocations and transformations when creating a token. --- .../CreateTokenTests.cs | 3 +- .../JsonWebTokenHandler.CreateToken.cs | 1383 +++++++++++++++++ .../JsonWebTokenHandler.cs | 958 +----------- .../JwtTokenUtilities.cs | 82 +- .../AsymmetricAdapter.cs | 183 ++- .../AsymmetricSignatureProvider.cs | 74 +- .../Base64UrlEncoder.cs | 64 +- .../Base64UrlEncoding.cs | 5 + .../RsaCryptoServiceProviderProxy.cs | 10 + .../SignatureProvider.cs | 11 + .../SupportedAlgorithms.cs | 41 + .../SymmetricSignatureProvider.cs | 70 +- .../TokenUtilities.cs | 2 +- .../JsonWebTokenHandlerTests.cs | 153 +- .../IdentityComparer.cs | 88 +- .../Base64UrlEncodingTests.cs | 311 ++++ .../CryptoProviderFactoryTests.cs | 10 +- .../IdentityComparerTests.cs | 2 +- .../SignatureProviderTests.cs | 410 ++++- 19 files changed, 2794 insertions(+), 1066 deletions(-) create mode 100644 src/Microsoft.IdentityModel.JsonWebTokens/JsonWebTokenHandler.CreateToken.cs create mode 100644 test/Microsoft.IdentityModel.Tokens.Tests/Base64UrlEncodingTests.cs diff --git a/benchmark/Microsoft.IdentityModel.Benchmarks/CreateTokenTests.cs b/benchmark/Microsoft.IdentityModel.Benchmarks/CreateTokenTests.cs index 7f753ba26c..85acf48301 100644 --- a/benchmark/Microsoft.IdentityModel.Benchmarks/CreateTokenTests.cs +++ b/benchmark/Microsoft.IdentityModel.Benchmarks/CreateTokenTests.cs @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +using System; using BenchmarkDotNet.Attributes; using Microsoft.IdentityModel.JsonWebTokens; using Microsoft.IdentityModel.Tokens; namespace Microsoft.IdentityModel.Benchmarks { - [HideColumns("Type", "Job", "WarmupCount", "LaunchCount")] [MemoryDiagnoser] public class CreateTokenTests { @@ -17,6 +17,7 @@ public class CreateTokenTests [GlobalSetup] public void Setup() { + DateTime now = DateTime.UtcNow; _jsonWebTokenHandler = new JsonWebTokenHandler(); _tokenDescriptor = new SecurityTokenDescriptor { diff --git a/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebTokenHandler.CreateToken.cs b/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebTokenHandler.CreateToken.cs new file mode 100644 index 0000000000..84be41f744 --- /dev/null +++ b/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebTokenHandler.CreateToken.cs @@ -0,0 +1,1383 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Security.Claims; +using System.Text; +using System.Text.Encodings.Web; +using System.Text.Json; +using Microsoft.IdentityModel.Abstractions; +using Microsoft.IdentityModel.Logging; +using Microsoft.IdentityModel.Tokens; +using JsonPrimitives = Microsoft.IdentityModel.Tokens.Json.JsonSerializerPrimitives; +using TokenLogMessages = Microsoft.IdentityModel.Tokens.LogMessages; + +namespace Microsoft.IdentityModel.JsonWebTokens +{ + /// + /// A designed for creating and validating Json Web Tokens. + /// See: https://datatracker.ietf.org/doc/html/rfc7519 and http://www.rfc-editor.org/info/rfc7515. + /// + /// This partial class is focused on TokenCreation. + public partial class JsonWebTokenHandler : TokenHandler + { + /// + /// Creates an unsigned JWS (Json Web Signature). + /// + /// A string containing JSON which represents the JWT token payload. + /// if is null. + /// A JWS in Compact Serialization Format. + public virtual string CreateToken( + string payload) + { + if (string.IsNullOrEmpty(payload)) + throw LogHelper.LogArgumentNullException(nameof(payload)); + + return CreateToken( + payload, + null, + null, + null, + null, + null, + null); + } + + /// + /// Creates an unsigned JWS (Json Web Signature). + /// + /// A string containing JSON which represents the JWT token payload. + /// Defines the dictionary containing any custom header claims that need to be added to the JWT token header. + /// if is null. + /// if is null. + /// A JWS in Compact Serialization Format. + public virtual string CreateToken( + string payload, + IDictionary additionalHeaderClaims) + { + if (string.IsNullOrEmpty(payload)) + throw LogHelper.LogArgumentNullException(nameof(payload)); + + _ = additionalHeaderClaims ?? throw LogHelper.LogArgumentNullException(nameof(additionalHeaderClaims)); + + return CreateToken(payload, + null, + null, + null, + additionalHeaderClaims, + null, + null); + } + + /// + /// Creates a JWS (Json Web Signature). + /// + /// A string containing JSON which represents the JWT token payload. + /// Defines the security key and algorithm that will be used to sign the JWS. + /// if is null. + /// if is null. + /// A JWS in Compact Serialization Format. + public virtual string CreateToken( + string payload, + SigningCredentials signingCredentials) + { + if (string.IsNullOrEmpty(payload)) + throw LogHelper.LogArgumentNullException(nameof(payload)); + + _ = signingCredentials ?? throw LogHelper.LogArgumentNullException(nameof(signingCredentials)); + + return CreateToken( + payload, + signingCredentials, + null, + null, + null, + null, + null); + } + + /// + /// Creates a JWS (Json Web Signature). + /// + /// A string containing JSON which represents the JWT token payload. + /// Defines the security key and algorithm that will be used to sign the JWS. + /// Defines the dictionary containing any custom header claims that need to be added to the JWT token header. + /// if is null. + /// if is null. + /// if is null. + /// if , + /// , , and/or + /// are present inside of . + /// A JWS in Compact Serialization Format. + public virtual string CreateToken( + string payload, + SigningCredentials signingCredentials, + IDictionary additionalHeaderClaims) + { + if (string.IsNullOrEmpty(payload)) + throw LogHelper.LogArgumentNullException(nameof(payload)); + + _ = signingCredentials ?? throw LogHelper.LogArgumentNullException(nameof(signingCredentials)); + _ = additionalHeaderClaims ?? throw LogHelper.LogArgumentNullException(nameof(additionalHeaderClaims)); + + return CreateToken( + payload, + signingCredentials, + null, + null, + additionalHeaderClaims, + null, + null); + } + + /// + /// Creates a JWt that can be a JWS or JWE. + /// + /// A that contains details of contents of the token. + /// A JWT in Compact Serialization Format. + public virtual string CreateToken(SecurityTokenDescriptor tokenDescriptor) + { + _ = tokenDescriptor ?? throw LogHelper.LogArgumentNullException(nameof(tokenDescriptor)); + + if (LogHelper.IsEnabled(EventLogLevel.Warning)) + { + if ((tokenDescriptor.Subject == null || !tokenDescriptor.Subject.Claims.Any()) + && (tokenDescriptor.Claims == null || !tokenDescriptor.Claims.Any())) + LogHelper.LogWarning( + LogMessages.IDX14114, LogHelper.MarkAsNonPII(nameof(SecurityTokenDescriptor)), LogHelper.MarkAsNonPII(nameof(SecurityTokenDescriptor.Subject)), LogHelper.MarkAsNonPII(nameof(SecurityTokenDescriptor.Claims))); + } + + if (tokenDescriptor.AdditionalHeaderClaims?.Count > 0 && tokenDescriptor.AdditionalHeaderClaims.Keys.Intersect(JwtTokenUtilities.DefaultHeaderParameters, StringComparer.OrdinalIgnoreCase).Any()) + throw LogHelper.LogExceptionMessage( + new SecurityTokenException( + LogHelper.FormatInvariant( + LogMessages.IDX14116, + LogHelper.MarkAsNonPII(nameof(tokenDescriptor.AdditionalHeaderClaims)), + LogHelper.MarkAsNonPII(string.Join(", ", JwtTokenUtilities.DefaultHeaderParameters))))); + + if (tokenDescriptor.AdditionalInnerHeaderClaims?.Count > 0 && tokenDescriptor.AdditionalInnerHeaderClaims.Keys.Intersect(JwtTokenUtilities.DefaultHeaderParameters, StringComparer.OrdinalIgnoreCase).Any()) + throw LogHelper.LogExceptionMessage( + new SecurityTokenException( + LogHelper.FormatInvariant( + LogMessages.IDX14116, + LogHelper.MarkAsNonPII(nameof(tokenDescriptor.AdditionalInnerHeaderClaims)), + LogHelper.MarkAsNonPII(string.Join(", ", JwtTokenUtilities.DefaultHeaderParameters))))); + + return CreateToken( + tokenDescriptor, + SetDefaultTimesOnTokenCreation, + TokenLifetimeInMinutes); + } + + internal static string CreateToken( + SecurityTokenDescriptor tokenDescriptor, + bool setdefaultTimesOnTokenCreation, + int tokenLifetimeInMinutes) + { + // The form of a JWS is: Base64UrlEncoding(UTF8(Header)) | . | Base64UrlEncoding(Payload) | . | Base64UrlEncoding(Signature) + // Where the Header is specifically the UTF8 bytes of the JSON, whereas the Payload encoding is not specified, but UTF8 is used by everyone. + // The signature is over ASCII(Utf8Bytes(Base64UrlEncoding(Header) | . | Base64UrlEncoding(Payload))) + // Since it is not known how large the JWS will be, a MemoryStream is used. + // An ArrayBufferWriter was benchmarked, while slightly faster, more memory is used and different code would be needed for 461+ and net6.0+ + // + // net6.0 has added api's that allow passing an allocated buffer when calculating the signature, so ArrayPool.Rent can be used. + + using (MemoryStream utf8ByteMemoryStream = new()) + { + Utf8JsonWriter writer = null; + char[] encodedChars = null; + byte[] asciiBytes = null; + byte[] signatureBytes = null; + + try + { + writer = new(utf8ByteMemoryStream, new JsonWriterOptions { Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping }); + + WriteJwsHeader( + ref writer, + tokenDescriptor.SigningCredentials, + tokenDescriptor.EncryptingCredentials, + tokenDescriptor.AdditionalHeaderClaims, + tokenDescriptor.AdditionalInnerHeaderClaims, + tokenDescriptor.TokenType); + + // mark length of jwt header + int headerLength = (int)utf8ByteMemoryStream.Length; + + // reset the writer and write the payload + writer.Reset(); + WriteJwsPayload( + ref writer, + tokenDescriptor, + setdefaultTimesOnTokenCreation, + tokenLifetimeInMinutes); + + // mark end of payload + int payloadEnd = (int)utf8ByteMemoryStream.Length; + int signatureSize = 0; + if (tokenDescriptor.SigningCredentials != null) + signatureSize = SupportedAlgorithms.GetMaxByteCount(tokenDescriptor.SigningCredentials.Algorithm); + + int encodedBufferSize = (payloadEnd + 4 + signatureSize) / 3 * 4; + encodedChars = ArrayPool.Shared.Rent(encodedBufferSize + 4); + + // Base64UrlEncode the Header + int sizeOfEncodedHeader = Base64UrlEncoder.Encode(utf8ByteMemoryStream.GetBuffer().AsSpan(0, headerLength), encodedChars); + encodedChars[sizeOfEncodedHeader] = '.'; + int sizeOfEncodedPayload = Base64UrlEncoder.Encode(utf8ByteMemoryStream.GetBuffer().AsSpan(headerLength, payloadEnd - headerLength), encodedChars.AsSpan(sizeOfEncodedHeader + 1)); + // encodeChars => 'EncodedHeader.EncodedPayload' + + // Get ASCII Bytes of 'EncodedHeader.EncodedPayload' which is used to calculate the signature + asciiBytes = ArrayPool.Shared.Rent(Encoding.ASCII.GetMaxByteCount(encodedBufferSize)); + int sizeOfEncodedHeaderAndPayloadAsciiBytes + = Encoding.ASCII.GetBytes(encodedChars, 0, sizeOfEncodedHeader + sizeOfEncodedPayload + 1, asciiBytes, 0); + + encodedChars[sizeOfEncodedHeader + sizeOfEncodedPayload + 1] = '.'; + // encodedChars => 'EncodedHeader.EncodedPayload.' + + int sizeOfEncodedSignature = 0; + if (tokenDescriptor.SigningCredentials != null) + { +#if NET6_0_OR_GREATER + signatureBytes = ArrayPool.Shared.Rent(signatureSize); + bool signatureSucceeded = JwtTokenUtilities.CreateSignature( + asciiBytes.AsSpan(0, sizeOfEncodedHeaderAndPayloadAsciiBytes), + signatureBytes, + tokenDescriptor.SigningCredentials, + out int signatureLength); +#else + signatureBytes = JwtTokenUtilities.CreateEncodedSignature(asciiBytes, 0, sizeOfEncodedHeaderAndPayloadAsciiBytes, tokenDescriptor.SigningCredentials); + int signatureLength = signatureBytes.Length; +#endif + sizeOfEncodedSignature = Base64UrlEncoder.Encode(signatureBytes.AsSpan(0, signatureLength), encodedChars.AsSpan(sizeOfEncodedHeader + sizeOfEncodedPayload + 2)); + } + + if (tokenDescriptor.EncryptingCredentials != null) + { + return EncryptToken( + Encoding.UTF8.GetBytes(encodedChars, 0, sizeOfEncodedHeader + sizeOfEncodedPayload + sizeOfEncodedSignature + 2), + tokenDescriptor.EncryptingCredentials, + tokenDescriptor.CompressionAlgorithm, + tokenDescriptor.AdditionalHeaderClaims, + tokenDescriptor.TokenType); + } + else + { + return encodedChars.AsSpan(0, sizeOfEncodedHeader + sizeOfEncodedPayload + sizeOfEncodedSignature + 2).ToString(); + } + } + finally + { + if (encodedChars is not null) + ArrayPool.Shared.Return(encodedChars); +#if NET6_0_OR_GREATER + if (signatureBytes is not null) + ArrayPool.Shared.Return(signatureBytes); +#endif + if (asciiBytes is not null) + ArrayPool.Shared.Return(asciiBytes); + + writer?.Dispose(); + } + } + } + + /// + /// Creates a JWE (Json Web Encryption). + /// + /// A string containing JSON which represents the JWT token payload. + /// Defines the security key and algorithm that will be used to encrypt the JWT. + /// A JWE in compact serialization format. + public virtual string CreateToken( + string payload, + EncryptingCredentials encryptingCredentials) + { + if (string.IsNullOrEmpty(payload)) + throw LogHelper.LogArgumentNullException(nameof(payload)); + + _ = encryptingCredentials ?? throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); + + return CreateToken( + payload, + null, + encryptingCredentials, + null, + null, + null, + null); + } + + /// + /// Creates a JWE (Json Web Encryption). + /// + /// A string containing JSON which represents the JWT token payload. + /// Defines the security key and algorithm that will be used to encrypt the JWT. + /// Defines the dictionary containing any custom header claims that need to be added to the outer JWT token header. + /// if is null. + /// if is null. + /// if is null. + /// if , + /// , , and/or + /// are present inside of . + /// A JWS in Compact Serialization Format. + public virtual string CreateToken( + string payload, + EncryptingCredentials encryptingCredentials, + IDictionary additionalHeaderClaims) + { + if (string.IsNullOrEmpty(payload)) + throw LogHelper.LogArgumentNullException(nameof(payload)); + + _ = encryptingCredentials ?? throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); + _ = additionalHeaderClaims ?? throw LogHelper.LogArgumentNullException(nameof(additionalHeaderClaims)); + + return CreateToken( + payload, + null, + encryptingCredentials, + null, + additionalHeaderClaims, + null, + null); + } + + /// + /// Creates a JWE (Json Web Encryption). + /// + /// A string containing JSON which represents the JWT token payload. + /// Defines the security key and algorithm that will be used to sign the JWT. + /// Defines the security key and algorithm that will be used to encrypt the JWT. + /// if is null. + /// if is null. + /// if is null. + /// A JWE in compact serialization format. + public virtual string CreateToken( + string payload, + SigningCredentials signingCredentials, + EncryptingCredentials encryptingCredentials) + { + if (string.IsNullOrEmpty(payload)) + throw LogHelper.LogArgumentNullException(nameof(payload)); + + _ = signingCredentials ?? throw LogHelper.LogArgumentNullException(nameof(signingCredentials)); + _ = encryptingCredentials ?? throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); + + return CreateToken( + payload, + signingCredentials, + encryptingCredentials, + null, + null, + null, + null); + } + + /// + /// Creates a JWE (Json Web Encryption). + /// + /// A string containing JSON which represents the JWT token payload. + /// Defines the security key and algorithm that will be used to sign the JWT. + /// Defines the security key and algorithm that will be used to encrypt the JWT. + /// Defines the dictionary containing any custom header claims that need to be added to the outer JWT token header. + /// if is null. + /// if is null. + /// if is null. + /// if is null. + /// if , + /// , , and/or + /// are present inside of . + /// A JWE in compact serialization format. + public virtual string CreateToken( + string payload, + SigningCredentials signingCredentials, + EncryptingCredentials encryptingCredentials, + IDictionary additionalHeaderClaims) + { + if (string.IsNullOrEmpty(payload)) + throw LogHelper.LogArgumentNullException(nameof(payload)); + + _ = signingCredentials ?? throw LogHelper.LogArgumentNullException(nameof(signingCredentials)); + _ = encryptingCredentials ?? throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); + _ = additionalHeaderClaims ?? throw LogHelper.LogArgumentNullException(nameof(additionalHeaderClaims)); + + return CreateToken( + payload, + signingCredentials, + encryptingCredentials, + null, + additionalHeaderClaims, + null, + null); + } + + /// + /// Creates a JWE (Json Web Encryption). + /// + /// A string containing JSON which represents the JWT token payload. + /// Defines the security key and algorithm that will be used to encrypt the JWT. + /// Defines the compression algorithm that will be used to compress the JWT token payload. + /// A JWE in compact serialization format. + public virtual string CreateToken( + string payload, + EncryptingCredentials encryptingCredentials, + string compressionAlgorithm) + { + if (string.IsNullOrEmpty(payload)) + throw LogHelper.LogArgumentNullException(nameof(payload)); + + if (string.IsNullOrEmpty(compressionAlgorithm)) + throw LogHelper.LogArgumentNullException(nameof(compressionAlgorithm)); + + _ = encryptingCredentials ?? throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); + + return CreateToken( + payload, + null, + encryptingCredentials, + compressionAlgorithm, + null, + null, + null); + } + + /// + /// Creates a JWE (Json Web Encryption). + /// + /// A string containing JSON which represents the JWT token payload. + /// Defines the security key and algorithm that will be used to sign the JWT. + /// Defines the security key and algorithm that will be used to encrypt the JWT. + /// Defines the compression algorithm that will be used to compress the JWT token payload. + /// if is null. + /// if is null. + /// if is null. + /// if is null. + /// A JWE in compact serialization format. + public virtual string CreateToken( + string payload, + SigningCredentials signingCredentials, + EncryptingCredentials encryptingCredentials, + string compressionAlgorithm) + { + if (string.IsNullOrEmpty(payload)) + throw LogHelper.LogArgumentNullException(nameof(payload)); + + if (string.IsNullOrEmpty(compressionAlgorithm)) + throw LogHelper.LogArgumentNullException(nameof(compressionAlgorithm)); + + _ = signingCredentials ?? throw LogHelper.LogArgumentNullException(nameof(signingCredentials)); + _ = encryptingCredentials ?? throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); + + return CreateToken( + payload, + signingCredentials, + encryptingCredentials, + compressionAlgorithm, + null, + null, + null); + } + + /// + /// Creates a JWE (Json Web Encryption). + /// + /// A string containing JSON which represents the JWT token payload. + /// Defines the security key and algorithm that will be used to sign the JWT. + /// Defines the security key and algorithm that will be used to encrypt the JWT. + /// Defines the compression algorithm that will be used to compress the JWT token payload. + /// Defines the dictionary containing any custom header claims that need to be added to the outer JWT token header. + /// Defines the dictionary containing any custom header claims that need to be added to the inner JWT token header. + /// if is null. + /// if is null. + /// if is null. + /// if is null. + /// if is null. + /// if , + /// , , and/or + /// are present inside of . + /// A JWE in compact serialization format. + public virtual string CreateToken( + string payload, + SigningCredentials signingCredentials, + EncryptingCredentials encryptingCredentials, + string compressionAlgorithm, + IDictionary additionalHeaderClaims, + IDictionary additionalInnerHeaderClaims) + { + if (string.IsNullOrEmpty(payload)) + throw LogHelper.LogArgumentNullException(nameof(payload)); + + if (string.IsNullOrEmpty(compressionAlgorithm)) + throw LogHelper.LogArgumentNullException(nameof(compressionAlgorithm)); + + _ = signingCredentials ?? throw LogHelper.LogArgumentNullException(nameof(signingCredentials)); + _ = encryptingCredentials ?? throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); + _ = additionalHeaderClaims ?? throw LogHelper.LogArgumentNullException(nameof(additionalHeaderClaims)); + _ = additionalInnerHeaderClaims ?? throw LogHelper.LogArgumentNullException(nameof(additionalInnerHeaderClaims)); + + return CreateToken( + payload, + signingCredentials, + encryptingCredentials, + compressionAlgorithm, + additionalHeaderClaims, + additionalInnerHeaderClaims, + null); + } + + /// + /// Creates a JWE (Json Web Encryption). + /// + /// A string containing JSON which represents the JWT token payload. + /// Defines the security key and algorithm that will be used to sign the JWT. + /// Defines the security key and algorithm that will be used to encrypt the JWT. + /// Defines the compression algorithm that will be used to compress the JWT token payload. + /// Defines the dictionary containing any custom header claims that need to be added to the outer JWT token header. + /// if is null. + /// if is null. + /// if is null. + /// if is null. + /// if is null. + /// if , + /// , , and/or + /// are present inside of . + /// A JWE in compact serialization format. + public virtual string CreateToken( + string payload, + SigningCredentials signingCredentials, + EncryptingCredentials encryptingCredentials, + string compressionAlgorithm, + IDictionary additionalHeaderClaims) + { + if (string.IsNullOrEmpty(payload)) + throw LogHelper.LogArgumentNullException(nameof(payload)); + + if (string.IsNullOrEmpty(compressionAlgorithm)) + throw LogHelper.LogArgumentNullException(nameof(compressionAlgorithm)); + + _ = signingCredentials ?? throw LogHelper.LogArgumentNullException(nameof(signingCredentials)); + _ = encryptingCredentials ?? throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); + _ = additionalHeaderClaims ?? throw LogHelper.LogArgumentNullException(nameof(additionalHeaderClaims)); + + return CreateToken( + payload, + signingCredentials, + encryptingCredentials, + compressionAlgorithm, + additionalHeaderClaims, + null, + null); + } + + internal static string CreateToken + ( + string payload, + SigningCredentials signingCredentials, + EncryptingCredentials encryptingCredentials, + string compressionAlgorithm, + IDictionary additionalHeaderClaims, + IDictionary additionalInnerHeaderClaims, + string tokenType) + { + using (MemoryStream utf8ByteMemoryStream = new ()) + { + Utf8JsonWriter writer = null; + char[] encodedChars = null; + byte[] asciiBytes = null; + byte[] signatureBytes = null; + byte[] payloadBytes = null; + + try + { + writer = new Utf8JsonWriter(utf8ByteMemoryStream, new JsonWriterOptions { Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping }); + + WriteJwsHeader( + ref writer, + signingCredentials, + encryptingCredentials, + additionalHeaderClaims, + additionalInnerHeaderClaims, + null); + + // mark length of jwt header + int headerLength = (int)utf8ByteMemoryStream.Length; + int signatureSize = 0; + if (signingCredentials != null) + signatureSize = SupportedAlgorithms.GetMaxByteCount(signingCredentials.Algorithm); + + payloadBytes = ArrayPool.Shared.Rent(Encoding.UTF8.GetMaxByteCount(payload.Length)); + int payloadSize = Encoding.UTF8.GetBytes(payload, 0, payload.Length, payloadBytes, 0); + + int encodedBufferSize = (headerLength + payloadSize + 4 + signatureSize) / 3 * 4; + encodedChars = ArrayPool.Shared.Rent(encodedBufferSize + 4); + + int sizeOfEncodedHeader = Base64UrlEncoder.Encode(utf8ByteMemoryStream.GetBuffer().AsSpan(0, headerLength), encodedChars); + encodedChars[sizeOfEncodedHeader] = '.'; + + int sizeOfEncodedPayload = Base64UrlEncoder.Encode(payloadBytes.AsSpan(0, payloadSize), encodedChars.AsSpan(sizeOfEncodedHeader + 1)); + // encodeChars => 'EncodedHeader.EncodedPayload' + + // Get ASCII Bytes of 'EncodedHeader.EncodedPayload' which is used to calculate the signature + asciiBytes = ArrayPool.Shared.Rent(Encoding.ASCII.GetMaxByteCount(encodedBufferSize)); + int sizeOfEncodedHeaderAndPayloadAsciiBytes + = Encoding.ASCII.GetBytes(encodedChars, 0, sizeOfEncodedHeader + sizeOfEncodedPayload + 1, asciiBytes, 0); + + encodedChars[sizeOfEncodedHeader + sizeOfEncodedPayload + 1] = '.'; + // encodedChars => 'EncodedHeader.EncodedPayload.' + + int sizeOfEncodedSignature = 0; + if (signingCredentials != null) + { +#if NET6_0_OR_GREATER + signatureBytes = ArrayPool.Shared.Rent(signatureSize); + bool signatureSucceeded = JwtTokenUtilities.CreateSignature( + asciiBytes.AsSpan(0, sizeOfEncodedHeaderAndPayloadAsciiBytes), + signatureBytes, + signingCredentials, + out int signatureLength); +#else + signatureBytes = JwtTokenUtilities.CreateEncodedSignature(asciiBytes, 0, sizeOfEncodedHeaderAndPayloadAsciiBytes, signingCredentials); + int signatureLength = signatureBytes.Length; +#endif + sizeOfEncodedSignature = Base64UrlEncoder.Encode(signatureBytes.AsSpan(0, signatureLength), encodedChars.AsSpan(sizeOfEncodedHeader + sizeOfEncodedPayload + 2)); + } + + if (encryptingCredentials != null) + { + return EncryptToken( + Encoding.UTF8.GetBytes(encodedChars, 0, sizeOfEncodedHeader + sizeOfEncodedPayload + sizeOfEncodedSignature + 2), + encryptingCredentials, + compressionAlgorithm, + additionalHeaderClaims, + tokenType); + } + else + { + return encodedChars.AsSpan(0, sizeOfEncodedHeader + sizeOfEncodedPayload + sizeOfEncodedSignature + 2).ToString(); + } + } + finally + { + if (encodedChars is not null) + ArrayPool.Shared.Return(encodedChars); +#if NET6_0_OR_GREATER + if (signatureBytes is not null) + ArrayPool.Shared.Return(signatureBytes); +#endif + if (asciiBytes is not null) + ArrayPool.Shared.Return(asciiBytes); + + if (payloadBytes is not null) + ArrayPool.Shared.Return(payloadBytes); + + writer?.Dispose(); + } + } + } + + /// + /// A can contain claims from multiple locations. + /// This method consolidates the claims and adds default times {exp, iat, nbf} if needed. + /// + /// + /// + /// + /// + /// A dictionary of claims. + internal static void WriteJwsPayload( + ref Utf8JsonWriter writer, + SecurityTokenDescriptor tokenDescriptor, + bool setDefaultTimesOnTokenCreation, + int tokenLifetimeInMinutes) + { + bool audienceChecked = false; + bool audienceSet = false; + bool issuerChecked = false; + bool issuerSet = false; + bool expChecked = false; + bool expSet = false; + bool iatChecked = false; + bool iatSet = false; + bool nbfChecked = false; + bool nbfSet = false; + + writer.WriteStartObject(); + + if (!string.IsNullOrEmpty(tokenDescriptor.Audience)) + { + audienceSet = true; + writer.WritePropertyName(JwtPayloadUtf8Bytes.Aud); + writer.WriteStringValue(tokenDescriptor.Audience); + } + + if (!string.IsNullOrEmpty(tokenDescriptor.Issuer)) + { + issuerSet = true; + writer.WritePropertyName(JwtPayloadUtf8Bytes.Iss); + writer.WriteStringValue(tokenDescriptor.Issuer); + } + + if (tokenDescriptor.Expires.HasValue) + { + expSet = true; + writer.WritePropertyName(JwtPayloadUtf8Bytes.Exp); + writer.WriteNumberValue(EpochTime.GetIntDate(tokenDescriptor.Expires.Value)); + } + + if (tokenDescriptor.IssuedAt.HasValue) + { + iatSet = true; + writer.WritePropertyName(JwtPayloadUtf8Bytes.Iat); + writer.WriteNumberValue(EpochTime.GetIntDate(tokenDescriptor.IssuedAt.Value)); + } + + if (tokenDescriptor.NotBefore.HasValue) + { + nbfSet = true; + writer.WritePropertyName(JwtPayloadUtf8Bytes.Nbf); + writer.WriteNumberValue(EpochTime.GetIntDate(tokenDescriptor.NotBefore.Value)); + } + + // Duplicates are resolved according to the following priority: + // SecurityTokenDescriptor.{Audience, Issuer, Expires, IssuedAt, NotBefore}, SecurityTokenDescriptor.Claims, SecurityTokenDescriptor.Subject.Claims + // SecurityTokenDescriptor.Claims are KeyValuePairs, whereas SeSecurityTokenDescriptor.Subject.Claims are System.Security.Claims.Claim and are processed differently. + + if (tokenDescriptor.Claims != null && tokenDescriptor.Claims.Count > 0) + { + foreach (KeyValuePair kvp in tokenDescriptor.Claims) + { + if (!audienceChecked && kvp.Key.Equals(JwtRegisteredClaimNames.Aud, StringComparison.Ordinal)) + { + audienceChecked = true; + if (audienceSet) + { + if (LogHelper.IsEnabled(EventLogLevel.Informational)) + LogHelper.LogInformation(LogHelper.FormatInvariant(LogMessages.IDX14113, LogHelper.MarkAsNonPII(nameof(tokenDescriptor.Audience)))); + + continue; + } + + audienceSet = true; + } + + if (!issuerChecked && kvp.Key.Equals(JwtRegisteredClaimNames.Iss, StringComparison.Ordinal)) + { + issuerChecked = true; + if (issuerSet) + { + if (LogHelper.IsEnabled(EventLogLevel.Informational)) + LogHelper.LogInformation(LogHelper.FormatInvariant(LogMessages.IDX14113, LogHelper.MarkAsNonPII(nameof(tokenDescriptor.Issuer)))); + + continue; + } + + issuerSet = true; + } + + if (!expChecked && kvp.Key.Equals(JwtRegisteredClaimNames.Exp, StringComparison.Ordinal)) + { + expChecked = true; + if (expSet) + { + if (LogHelper.IsEnabled(EventLogLevel.Informational)) + LogHelper.LogInformation(LogHelper.FormatInvariant(LogMessages.IDX14113, LogHelper.MarkAsNonPII(nameof(tokenDescriptor.Expires)))); + + continue; + } + + expSet = true; + } + + if (!iatChecked && kvp.Key.Equals(JwtRegisteredClaimNames.Iat, StringComparison.Ordinal)) + { + iatChecked = true; + if (iatSet) + { + if (LogHelper.IsEnabled(EventLogLevel.Informational)) + LogHelper.LogInformation(LogHelper.FormatInvariant(LogMessages.IDX14113, LogHelper.MarkAsNonPII(nameof(tokenDescriptor.Expires)))); + + continue; + } + + iatSet = true; + } + + if (!nbfChecked && kvp.Key.Equals(JwtRegisteredClaimNames.Nbf, StringComparison.Ordinal)) + { + nbfChecked = true; + if (nbfSet) + { + if (LogHelper.IsEnabled(EventLogLevel.Informational)) + LogHelper.LogInformation(LogHelper.FormatInvariant(LogMessages.IDX14113, LogHelper.MarkAsNonPII(nameof(tokenDescriptor.Expires)))); + + continue; + } + + nbfSet = true; + } + + JsonPrimitives.WriteObject(ref writer, kvp.Key, kvp.Value); + } + } + + AddSubjectClaims(ref writer, tokenDescriptor, audienceSet, issuerSet, ref expSet, ref iatSet, ref nbfSet); + + // By default we set these three properties only if they haven't been detected before. + if (setDefaultTimesOnTokenCreation && !(expSet && iatSet && nbfSet)) + { + DateTime now = DateTime.UtcNow; + + if (!expSet) + { + writer.WritePropertyName(JwtPayloadUtf8Bytes.Exp); + writer.WriteNumberValue(EpochTime.GetIntDate(now + TimeSpan.FromMinutes(tokenLifetimeInMinutes))); + } + + if (!iatSet) + { + writer.WritePropertyName(JwtPayloadUtf8Bytes.Iat); + writer.WriteNumberValue(EpochTime.GetIntDate(now)); + } + + if (!nbfSet) + { + writer.WritePropertyName(JwtPayloadUtf8Bytes.Nbf); + writer.WriteNumberValue(EpochTime.GetIntDate(now)); + } + } + + writer.WriteEndObject(); + writer.Flush(); + } + + internal static void AddSubjectClaims( + ref Utf8JsonWriter writer, + SecurityTokenDescriptor tokenDescriptor, + bool audienceSet, + bool issuerSet, + ref bool expSet, + ref bool iatSet, + ref bool nbfSet) + { + if (tokenDescriptor.Subject == null) + return; + + bool expReset = false; + bool iatReset = false; + bool nbfReset = false; + + var payload = new Dictionary(); + + bool checkClaims = tokenDescriptor.Claims != null && tokenDescriptor.Claims.Count > 0; + + foreach (Claim claim in tokenDescriptor.Subject.Claims) + { + if (claim == null) + continue; + + // skipping these as they have been added by values in the SecurityTokenDescriptor + if (checkClaims && tokenDescriptor.Claims.ContainsKey(claim.Type)) + continue; + + if (audienceSet && claim.Type.Equals(JwtRegisteredClaimNames.Aud, StringComparison.Ordinal)) + continue; + + if (issuerSet && claim.Type.Equals(JwtRegisteredClaimNames.Iss, StringComparison.Ordinal)) + continue; + + if (claim.Type.Equals(JwtRegisteredClaimNames.Exp, StringComparison.Ordinal)) + { + if (expSet) + continue; + + expReset = true; + } + + if (claim.Type.Equals(JwtRegisteredClaimNames.Iat, StringComparison.Ordinal)) + { + if (iatSet) + continue; + + iatReset = true; + } + + if (claim.Type.Equals(JwtRegisteredClaimNames.Nbf, StringComparison.Ordinal)) + { + if (nbfSet) + continue; + + nbfReset = true; + } + + object jsonClaimValue = claim.ValueType.Equals(ClaimValueTypes.String) ? claim.Value : TokenUtilities.GetClaimValueUsingValueType(claim); + + // The enumeration is from ClaimsIdentity.Claims, there can be duplicates. + // When a duplicate is detected, we create a List and add both to a list. + // When the creating the JWT and a list is found, a JsonArray will be created. + if (payload.TryGetValue(claim.Type, out object existingValue)) + { + if (existingValue is List existingList) + { + existingList.Add(jsonClaimValue); + } + else + { + payload[claim.Type] = new List + { + existingValue, + jsonClaimValue + }; + } + } + else + { + payload[claim.Type] = jsonClaimValue; + } + } + + foreach (KeyValuePair kvp in payload) + JsonPrimitives.WriteObject(ref writer, kvp.Key, kvp.Value); + + expSet |= expReset; + iatSet |= iatReset; + nbfSet |= nbfReset; + } + + internal static void WriteJwsHeader( + ref Utf8JsonWriter writer, + SigningCredentials signingCredentials, + EncryptingCredentials encryptingCredentials, + IDictionary jweHeaderClaims, + IDictionary jwsHeaderClaims, + string tokenType) + { + if (jweHeaderClaims?.Count > 0 && jweHeaderClaims.Keys.Intersect(JwtTokenUtilities.DefaultHeaderParameters, StringComparer.OrdinalIgnoreCase).Any()) + throw LogHelper.LogExceptionMessage( + new SecurityTokenException( + LogHelper.FormatInvariant( + LogMessages.IDX14116, + LogHelper.MarkAsNonPII(nameof(jweHeaderClaims)), + LogHelper.MarkAsNonPII(string.Join(", ", JwtTokenUtilities.DefaultHeaderParameters))))); + + if (jwsHeaderClaims?.Count > 0 && jwsHeaderClaims.Keys.Intersect(JwtTokenUtilities.DefaultHeaderParameters, StringComparer.OrdinalIgnoreCase).Any()) + throw LogHelper.LogExceptionMessage( + new SecurityTokenException( + LogHelper.FormatInvariant( + LogMessages.IDX14116, + LogHelper.MarkAsNonPII(nameof(jwsHeaderClaims)), + LogHelper.MarkAsNonPII(string.Join(", ", JwtTokenUtilities.DefaultHeaderParameters))))); + + + // If token is a JWE, jweHeaderClaims go in outer header. + bool addJweHeaderClaims = encryptingCredentials is null && jweHeaderClaims?.Count > 0; + bool addJwsHeaderClaims = jwsHeaderClaims?.Count > 0; + bool typeWritten = false; + writer.WriteStartObject(); + + if (signingCredentials == null) + { + writer.WriteString(JwtHeaderUtf8Bytes.Alg, SecurityAlgorithms.None); + } + else + { + writer.WriteString(JwtHeaderUtf8Bytes.Alg, signingCredentials.Algorithm); + if (signingCredentials.Key.KeyId != null) + writer.WriteString(JwtHeaderUtf8Bytes.Kid, signingCredentials.Key.KeyId); + + if (signingCredentials.Key is X509SecurityKey x509SecurityKey) + writer.WriteString(JwtHeaderUtf8Bytes.X5t, x509SecurityKey.X5t); + } + + // Priority is additionalInnerHeaderClaims, additionalHeaderClaims, defaults + if (addJweHeaderClaims) + { + foreach (KeyValuePair kvp in jweHeaderClaims) + { + if (addJwsHeaderClaims && jwsHeaderClaims.ContainsKey(kvp.Key)) + continue; + + JsonPrimitives.WriteObject(ref writer, kvp.Key, kvp.Value); + if (!typeWritten && kvp.Key.Equals(JwtHeaderParameterNames.Typ, StringComparison.Ordinal)) + typeWritten = true; + } + } + + if (addJwsHeaderClaims) + { + foreach (KeyValuePair kvp in jwsHeaderClaims) + { + JsonPrimitives.WriteObject(ref writer, kvp.Key, kvp.Value); + if (!typeWritten && kvp.Key.Equals(JwtHeaderParameterNames.Typ, StringComparison.Ordinal)) + typeWritten = true; + } + } + + if (!typeWritten) + writer.WriteString(JwtHeaderUtf8Bytes.Typ, string.IsNullOrEmpty(tokenType) ? JwtConstants.HeaderType : tokenType); + + writer.WriteEndObject(); + writer.Flush(); + } + + internal static byte[] WriteJweHeader( + EncryptingCredentials encryptingCredentials, + string compressionAlgorithm, + string tokenType, + IDictionary jweHeaderClaims) + { + using (MemoryStream memoryStream = new()) + { + Utf8JsonWriter writer = null; + try + { + writer = new Utf8JsonWriter(memoryStream, new JsonWriterOptions { Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping }); + writer.WriteStartObject(); + + writer.WriteString(JwtHeaderUtf8Bytes.Alg, encryptingCredentials.Alg); + writer.WriteString(JwtHeaderUtf8Bytes.Enc, encryptingCredentials.Enc); + + if (encryptingCredentials.Key.KeyId != null) + writer.WriteString(JwtHeaderUtf8Bytes.Kid, encryptingCredentials.Key.KeyId); + + if (!string.IsNullOrEmpty(compressionAlgorithm)) + writer.WriteString(JwtHeaderUtf8Bytes.Zip, compressionAlgorithm); + + bool typeWritten = false; + bool ctyWritten = !encryptingCredentials.SetDefaultCtyClaim; + + // Current 6x Priority is jweHeaderClaims, type, cty + if (jweHeaderClaims != null && jweHeaderClaims.Count > 0) + { + foreach (KeyValuePair kvp in jweHeaderClaims) + { + JsonPrimitives.WriteObject(ref writer, kvp.Key, kvp.Value); + if (!typeWritten && kvp.Key.Equals(JwtHeaderParameterNames.Typ, StringComparison.Ordinal)) + typeWritten = true; + else if (!ctyWritten && kvp.Key.Equals(JwtHeaderParameterNames.Cty, StringComparison.Ordinal)) + ctyWritten = true; + } + } + + if (!typeWritten) + writer.WriteString(JwtHeaderUtf8Bytes.Typ, string.IsNullOrEmpty(tokenType) ? JwtConstants.HeaderType : tokenType); + + if (!ctyWritten) + writer.WriteString(JwtHeaderUtf8Bytes.Cty, JwtConstants.HeaderType); + + writer.WriteEndObject(); + writer.Flush(); + + return memoryStream.ToArray(); + } + finally + { + writer?.Dispose(); + } + } + } + + internal static byte[] CompressToken(byte[] utf8Bytes, string compressionAlgorithm) + { + if (string.IsNullOrEmpty(compressionAlgorithm)) + throw LogHelper.LogArgumentNullException(nameof(compressionAlgorithm)); + + if (!CompressionProviderFactory.Default.IsSupportedAlgorithm(compressionAlgorithm)) + throw LogHelper.LogExceptionMessage(new NotSupportedException(LogHelper.FormatInvariant(TokenLogMessages.IDX10682, LogHelper.MarkAsNonPII(compressionAlgorithm)))); + + var compressionProvider = CompressionProviderFactory.Default.CreateCompressionProvider(compressionAlgorithm); + + return compressionProvider.Compress(utf8Bytes) ?? throw LogHelper.LogExceptionMessage(new InvalidOperationException(LogHelper.FormatInvariant(TokenLogMessages.IDX10680, LogHelper.MarkAsNonPII(compressionAlgorithm)))); + } + + /// + /// Encrypts a JWS. + /// + /// A 'JSON Web Token' (JWT) in JWS Compact Serialization Format. + /// Defines the security key and algorithm that will be used to encrypt the . + /// if is null or empty. + /// if is null. + /// if both and . are null. + /// if the CryptoProviderFactory being used does not support the (algorithm), pair. + /// if unable to create a token encryption provider for the (algorithm), pair. + /// if encryption fails using the (algorithm), pair. + /// if not using one of the supported content encryption key (CEK) algorithms: 128, 384 or 512 AesCbcHmac (this applies in the case of key wrap only, not direct encryption). + public string EncryptToken(string innerJwt, EncryptingCredentials encryptingCredentials) + { + if (string.IsNullOrEmpty(innerJwt)) + throw LogHelper.LogArgumentNullException(nameof(innerJwt)); + + if (encryptingCredentials == null) + throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); + + return EncryptTokenPrivate(innerJwt, encryptingCredentials, null, null, null); + } + + /// + /// Encrypts a JWS. + /// + /// A 'JSON Web Token' (JWT) in JWS Compact Serialization Format. + /// Defines the security key and algorithm that will be used to encrypt the . + /// Defines the dictionary containing any custom header claims that need to be added to the outer JWT token header. + /// if is null or empty. + /// if is null. + /// if is null. + /// if both and . are null. + /// if the CryptoProviderFactory being used does not support the (algorithm), pair. + /// if unable to create a token encryption provider for the (algorithm), pair. + /// if encryption fails using the (algorithm), pair. + /// if not using one of the supported content encryption key (CEK) algorithms: 128, 384 or 512 AesCbcHmac (this applies in the case of key wrap only, not direct encryption). + public string EncryptToken( + string innerJwt, + EncryptingCredentials encryptingCredentials, + IDictionary additionalHeaderClaims) + { + if (string.IsNullOrEmpty(innerJwt)) + throw LogHelper.LogArgumentNullException(nameof(innerJwt)); + + if (encryptingCredentials == null) + throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); + + if (additionalHeaderClaims == null) + throw LogHelper.LogArgumentNullException(nameof(additionalHeaderClaims)); + + return EncryptTokenPrivate(innerJwt, encryptingCredentials, null, additionalHeaderClaims, null); + } + + /// + /// Encrypts a JWS. + /// + /// A 'JSON Web Token' (JWT) in JWS Compact Serialization Format. + /// Defines the security key and algorithm that will be used to encrypt the . + /// Defines the compression algorithm that will be used to compress the 'innerJwt'. + /// if is null or empty. + /// if is null. + /// if is null or empty. + /// if both and . are null. + /// if the CryptoProviderFactory being used does not support the (algorithm), pair. + /// if unable to create a token encryption provider for the (algorithm), pair. + /// if compression using fails. + /// if encryption fails using the (algorithm), pair. + /// if not using one of the supported content encryption key (CEK) algorithms: 128, 384 or 512 AesCbcHmac (this applies in the case of key wrap only, not direct encryption). + public string EncryptToken( + string innerJwt, + EncryptingCredentials encryptingCredentials, + string algorithm) + { + if (string.IsNullOrEmpty(innerJwt)) + throw LogHelper.LogArgumentNullException(nameof(innerJwt)); + + if (encryptingCredentials == null) + throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); + + if (string.IsNullOrEmpty(algorithm)) + throw LogHelper.LogArgumentNullException(nameof(algorithm)); + + return EncryptTokenPrivate(innerJwt, encryptingCredentials, algorithm, null, null); + } + + /// + /// Encrypts a JWS. + /// + /// A 'JSON Web Token' (JWT) in JWS Compact Serialization Format. + /// Defines the security key and algorithm that will be used to encrypt the . + /// Defines the compression algorithm that will be used to compress the + /// Defines the dictionary containing any custom header claims that need to be added to the outer JWT token header. + /// if is null or empty. + /// if is null. + /// if is null or empty. + /// if is null or empty. + /// if both and . are null. + /// if the CryptoProviderFactory being used does not support the (algorithm), pair. + /// if unable to create a token encryption provider for the (algorithm), pair. + /// if compression using 'algorithm' fails. + /// if encryption fails using the (algorithm), pair. + /// if not using one of the supported content encryption key (CEK) algorithms: 128, 384 or 512 AesCbcHmac (this applies in the case of key wrap only, not direct encryption). + public string EncryptToken( + string innerJwt, + EncryptingCredentials encryptingCredentials, + string algorithm, + IDictionary additionalHeaderClaims) + { + if (string.IsNullOrEmpty(innerJwt)) + throw LogHelper.LogArgumentNullException(nameof(innerJwt)); + + if (encryptingCredentials == null) + throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); + + if (string.IsNullOrEmpty(algorithm)) + throw LogHelper.LogArgumentNullException(nameof(algorithm)); + + if (additionalHeaderClaims == null) + throw LogHelper.LogArgumentNullException(nameof(additionalHeaderClaims)); + + return EncryptTokenPrivate(innerJwt, encryptingCredentials, algorithm, additionalHeaderClaims, null); + } + + private static string EncryptTokenPrivate( + string innerJwt, + EncryptingCredentials encryptingCredentials, + string compressionAlgorithm, + IDictionary additionalHeaderClaims, + string tokenType) + { + return (EncryptToken( + Encoding.UTF8.GetBytes(innerJwt), + encryptingCredentials, + compressionAlgorithm, + additionalHeaderClaims, + tokenType)); + } + + internal static string EncryptToken( + byte[] innerTokenUtf8Bytes, + EncryptingCredentials encryptingCredentials, + string compressionAlgorithm, + IDictionary additionalHeaderClaims, + string tokenType) + { + CryptoProviderFactory cryptoProviderFactory = encryptingCredentials.CryptoProviderFactory ?? encryptingCredentials.Key.CryptoProviderFactory; + + if (cryptoProviderFactory == null) + throw LogHelper.LogExceptionMessage(new ArgumentException(TokenLogMessages.IDX10620)); + + SecurityKey securityKey = JwtTokenUtilities.GetSecurityKey(encryptingCredentials, cryptoProviderFactory, additionalHeaderClaims, out byte[] wrappedKey); + + using (AuthenticatedEncryptionProvider encryptionProvider = cryptoProviderFactory.CreateAuthenticatedEncryptionProvider(securityKey, encryptingCredentials.Enc)) + { + if (encryptionProvider == null) + throw LogHelper.LogExceptionMessage(new SecurityTokenEncryptionFailedException(LogMessages.IDX14103)); + + byte[] jweHeader = WriteJweHeader(encryptingCredentials, compressionAlgorithm, tokenType, additionalHeaderClaims); + byte[] plainText; + if (!string.IsNullOrEmpty(compressionAlgorithm)) + { + try + { + plainText = CompressToken(innerTokenUtf8Bytes, compressionAlgorithm); + } + catch (Exception ex) + { + throw LogHelper.LogExceptionMessage(new SecurityTokenCompressionFailedException(LogHelper.FormatInvariant(TokenLogMessages.IDX10680, LogHelper.MarkAsNonPII(compressionAlgorithm)), ex)); + } + } + else + { + plainText = innerTokenUtf8Bytes; + } + + try + { + string rawHeader = Base64UrlEncoder.Encode(jweHeader); + + var encryptionResult = encryptionProvider.Encrypt(plainText, Encoding.ASCII.GetBytes(rawHeader)); + return JwtConstants.DirectKeyUseAlg.Equals(encryptingCredentials.Alg) ? + string.Join(".", rawHeader, string.Empty, Base64UrlEncoder.Encode(encryptionResult.IV), Base64UrlEncoder.Encode(encryptionResult.Ciphertext), Base64UrlEncoder.Encode(encryptionResult.AuthenticationTag)) : + string.Join(".", rawHeader, Base64UrlEncoder.Encode(wrappedKey), Base64UrlEncoder.Encode(encryptionResult.IV), Base64UrlEncoder.Encode(encryptionResult.Ciphertext), Base64UrlEncoder.Encode(encryptionResult.AuthenticationTag)); + } + catch (Exception ex) + { + throw LogHelper.LogExceptionMessage(new SecurityTokenEncryptionFailedException(LogHelper.FormatInvariant(TokenLogMessages.IDX10616, LogHelper.MarkAsNonPII(encryptingCredentials.Enc), encryptingCredentials.Key), ex)); + } + } + } + + internal IEnumerable GetContentEncryptionKeys(JsonWebToken jwtToken, TokenValidationParameters validationParameters, BaseConfiguration configuration) + { + IEnumerable keys = null; + + // First we check to see if the caller has set a custom decryption resolver on TVP for the call, if so any keys set on TVP and keys in Configuration are ignored. + // If no custom decryption resolver set, we'll check to see if they've set some static decryption keys on TVP. If a key found, we ignore configuration. + // If no key found in TVP, we'll check the configuration. + if (validationParameters.TokenDecryptionKeyResolver != null) + { + keys = validationParameters.TokenDecryptionKeyResolver(jwtToken.EncodedToken, jwtToken, jwtToken.Kid, validationParameters); + } + else + { + var key = ResolveTokenDecryptionKey(jwtToken.EncodedToken, jwtToken, validationParameters); + if (key != null) + { + if (LogHelper.IsEnabled(EventLogLevel.Informational)) + LogHelper.LogInformation(TokenLogMessages.IDX10904, key); + } + else if (configuration != null) + { + key = ResolveTokenDecryptionKeyFromConfig(jwtToken, configuration); + if (key != null && LogHelper.IsEnabled(EventLogLevel.Informational)) + LogHelper.LogInformation(TokenLogMessages.IDX10905, key); + } + + if (key != null) + keys = new List { key }; + } + + // on decryption for ECDH-ES, we get the public key from the EPK value see: https://datatracker.ietf.org/doc/html/rfc7518#appendix-C + // we need the ECDSASecurityKey for the receiver, use TokenValidationParameters.TokenDecryptionKey + + // control gets here if: + // 1. User specified delegate: TokenDecryptionKeyResolver returned null + // 2. ResolveTokenDecryptionKey returned null + // 3. ResolveTokenDecryptionKeyFromConfig returned null + // Try all the keys. This is the degenerate case, not concerned about perf. + if (keys == null) + { + keys = JwtTokenUtilities.GetAllDecryptionKeys(validationParameters); + if (configuration != null) + keys = keys == null ? configuration.TokenDecryptionKeys : keys.Concat(configuration.TokenDecryptionKeys); + } + + if (jwtToken.Alg.Equals(JwtConstants.DirectKeyUseAlg, StringComparison.Ordinal) + || jwtToken.Alg.Equals(SecurityAlgorithms.EcdhEs, StringComparison.Ordinal)) + return keys; + + var unwrappedKeys = new List(); + // keep track of exceptions thrown, keys that were tried + StringBuilder exceptionStrings = null; + StringBuilder keysAttempted = null; + foreach (var key in keys) + { + try + { +#if NET472 || NET6_0_OR_GREATER + if (SupportedAlgorithms.EcdsaWrapAlgorithms.Contains(jwtToken.Alg)) + { + // on decryption we get the public key from the EPK value see: https://datatracker.ietf.org/doc/html/rfc7518#appendix-C + var ecdhKeyExchangeProvider = new EcdhKeyExchangeProvider( + key as ECDsaSecurityKey, + validationParameters.TokenDecryptionKey as ECDsaSecurityKey, + jwtToken.Alg, + jwtToken.Enc); + jwtToken.TryGetHeaderValue(JwtHeaderParameterNames.Apu, out string apu); + jwtToken.TryGetHeaderValue(JwtHeaderParameterNames.Apv, out string apv); + SecurityKey kdf = ecdhKeyExchangeProvider.GenerateKdf(apu, apv); + var kwp = key.CryptoProviderFactory.CreateKeyWrapProviderForUnwrap(kdf, ecdhKeyExchangeProvider.GetEncryptionAlgorithm()); + var unwrappedKey = kwp.UnwrapKey(Base64UrlEncoder.DecodeBytes(jwtToken.EncryptedKey)); + unwrappedKeys.Add(new SymmetricSecurityKey(unwrappedKey)); + } + else +#endif + if (key.CryptoProviderFactory.IsSupportedAlgorithm(jwtToken.Alg, key)) + { + var kwp = key.CryptoProviderFactory.CreateKeyWrapProviderForUnwrap(key, jwtToken.Alg); + var unwrappedKey = kwp.UnwrapKey(jwtToken.EncryptedKeyBytes); + unwrappedKeys.Add(new SymmetricSecurityKey(unwrappedKey)); + } + } + catch (Exception ex) + { + (exceptionStrings ??= new StringBuilder()).AppendLine(ex.ToString()); + } + + (keysAttempted ??= new StringBuilder()).AppendLine(key.ToString()); + } + + if (unwrappedKeys.Count > 0 && exceptionStrings is null) + return unwrappedKeys; + else + throw LogHelper.LogExceptionMessage(new SecurityTokenKeyWrapException(LogHelper.FormatInvariant(TokenLogMessages.IDX10618, (object)keysAttempted ?? "", (object)exceptionStrings ?? "", jwtToken))); + } + } +} diff --git a/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebTokenHandler.cs b/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebTokenHandler.cs index f5c90ddcd6..781c0d5d3e 100644 --- a/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebTokenHandler.cs +++ b/src/Microsoft.IdentityModel.JsonWebTokens/JsonWebTokenHandler.cs @@ -3,19 +3,15 @@ using System; using System.Collections.Generic; -using System.IO; using System.Linq; using System.Security.Claims; using System.Text; -using System.Text.Encodings.Web; -using System.Text.Json; using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; using Microsoft.IdentityModel.Abstractions; using Microsoft.IdentityModel.Logging; using Microsoft.IdentityModel.Tokens; -using JsonPrimitives = Microsoft.IdentityModel.Tokens.Json.JsonSerializerPrimitives; using TokenLogMessages = Microsoft.IdentityModel.Tokens.LogMessages; namespace Microsoft.IdentityModel.JsonWebTokens @@ -24,7 +20,7 @@ namespace Microsoft.IdentityModel.JsonWebTokens /// A designed for creating and validating Json Web Tokens. /// See: https://datatracker.ietf.org/doc/html/rfc7519 and http://www.rfc-editor.org/info/rfc7515. /// - public class JsonWebTokenHandler : TokenHandler + public partial class JsonWebTokenHandler : TokenHandler { private IDictionary _inboundClaimTypeMap; private const string _namespace = "http://schemas.xmlsoap.org/ws/2005/05/identity/claimproperties"; @@ -101,7 +97,7 @@ public bool MapInboundClaims } set { - if(!_mapInboundClaims && value && _inboundClaimTypeMap.Count == 0) + if (!_mapInboundClaims && value && _inboundClaimTypeMap.Count == 0) _inboundClaimTypeMap = new Dictionary(DefaultInboundClaimTypeMap); _mapInboundClaims = value; } @@ -187,674 +183,6 @@ public virtual bool CanValidateToken get { return true; } } - /// - /// Creates an unsigned JWS (Json Web Signature). - /// - /// A string containing JSON which represents the JWT token payload. - /// if is null. - /// A JWS in Compact Serialization Format. - public virtual string CreateToken(string payload) - { - if (string.IsNullOrEmpty(payload)) - throw LogHelper.LogArgumentNullException(nameof(payload)); - - return CreateToken(Encoding.UTF8.GetBytes(payload), null, null, null, null, null, null); - } - - /// - /// Creates an unsigned JWS (Json Web Signature). - /// - /// A string containing JSON which represents the JWT token payload. - /// Defines the dictionary containing any custom header claims that need to be added to the JWT token header. - /// if is null. - /// if is null. - /// A JWS in Compact Serialization Format. - public virtual string CreateToken(string payload, IDictionary additionalHeaderClaims) - { - if (string.IsNullOrEmpty(payload)) - throw LogHelper.LogArgumentNullException(nameof(payload)); - - if (additionalHeaderClaims == null) - throw LogHelper.LogArgumentNullException(nameof(additionalHeaderClaims)); - - return CreateToken(Encoding.UTF8.GetBytes(payload), null, null, null, additionalHeaderClaims, null, null); - } - - /// - /// Creates a JWS (Json Web Signature). - /// - /// A string containing JSON which represents the JWT token payload. - /// Defines the security key and algorithm that will be used to sign the JWS. - /// if is null. - /// if is null. - /// A JWS in Compact Serialization Format. - public virtual string CreateToken(string payload, SigningCredentials signingCredentials) - { - if (string.IsNullOrEmpty(payload)) - throw LogHelper.LogArgumentNullException(nameof(payload)); - - if (signingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(signingCredentials)); - - return CreateToken(Encoding.UTF8.GetBytes(payload), signingCredentials, null, null, null, null, null); - } - - /// - /// Creates a JWS (Json Web Signature). - /// - /// A string containing JSON which represents the JWT token payload. - /// Defines the security key and algorithm that will be used to sign the JWS. - /// Defines the dictionary containing any custom header claims that need to be added to the JWT token header. - /// if is null. - /// if is null. - /// if is null. - /// if , - /// , , and/or - /// are present inside of . - /// A JWS in Compact Serialization Format. - public virtual string CreateToken(string payload, SigningCredentials signingCredentials, IDictionary additionalHeaderClaims) - { - if (string.IsNullOrEmpty(payload)) - throw LogHelper.LogArgumentNullException(nameof(payload)); - - if (signingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(signingCredentials)); - - if (additionalHeaderClaims == null) - throw LogHelper.LogArgumentNullException(nameof(additionalHeaderClaims)); - - return CreateToken(Encoding.UTF8.GetBytes(payload), signingCredentials, null, null, additionalHeaderClaims, null, null); - } - - /// - /// Creates a JWS(Json Web Signature). - /// - /// A that contains details of contents of the token. - /// A JWS in Compact Serialization Format. - public virtual string CreateToken(SecurityTokenDescriptor tokenDescriptor) - { - _ = tokenDescriptor ?? throw LogHelper.LogArgumentNullException(nameof(tokenDescriptor)); - - if (LogHelper.IsEnabled(EventLogLevel.Warning)) - { - if ((tokenDescriptor.Subject == null || !tokenDescriptor.Subject.Claims.Any()) - && (tokenDescriptor.Claims == null || !tokenDescriptor.Claims.Any())) - LogHelper.LogWarning( - LogMessages.IDX14114, LogHelper.MarkAsNonPII(nameof(SecurityTokenDescriptor)), LogHelper.MarkAsNonPII(nameof(SecurityTokenDescriptor.Subject)), LogHelper.MarkAsNonPII(nameof(SecurityTokenDescriptor.Claims))); - } - - return CreateToken( - WritePayload(tokenDescriptor), - tokenDescriptor.SigningCredentials, - tokenDescriptor.EncryptingCredentials, - tokenDescriptor.CompressionAlgorithm, - tokenDescriptor.AdditionalHeaderClaims, - tokenDescriptor.AdditionalInnerHeaderClaims, - tokenDescriptor.TokenType); - } - - internal static byte[] WriteJwsHeader( - SigningCredentials signingCredentials, - string tokenType, - IDictionary jwsHeaderClaims, - IDictionary jweHeaderClaims) - { - using (MemoryStream memoryStream = new MemoryStream()) - { - Utf8JsonWriter writer = null; - try - { - writer = new Utf8JsonWriter(memoryStream, new JsonWriterOptions { Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping }); - writer.WriteStartObject(); - - if (signingCredentials == null) - { - writer.WriteString(JwtHeaderUtf8Bytes.Alg, SecurityAlgorithms.None); - } - else - { - writer.WriteString(JwtHeaderUtf8Bytes.Alg, signingCredentials.Algorithm); - if (signingCredentials.Key.KeyId != null) - writer.WriteString(JwtHeaderUtf8Bytes.Kid, signingCredentials.Key.KeyId); - - if (signingCredentials.Key is X509SecurityKey x509SecurityKey) - writer.WriteString(JwtHeaderUtf8Bytes.X5t, x509SecurityKey.X5t); - } - - bool useJwsHeaderClaims = jwsHeaderClaims != null && jwsHeaderClaims.Count > 0; - bool typeWritten = false; - - // Priority is jwsHeaderClaims, jweHeaderClaims, default - if (jweHeaderClaims != null && jweHeaderClaims.Count > 0) - { - foreach (KeyValuePair kvp in jweHeaderClaims) - { - if (useJwsHeaderClaims && jwsHeaderClaims.ContainsKey(kvp.Key)) - continue; - - JsonPrimitives.WriteObject(ref writer, kvp.Key, kvp.Value); - if (!typeWritten && kvp.Key.Equals(JwtHeaderParameterNames.Typ, StringComparison.Ordinal)) - typeWritten = true; - } - } - - if (useJwsHeaderClaims) - { - foreach (KeyValuePair kvp in jwsHeaderClaims) - { - JsonPrimitives.WriteObject(ref writer, kvp.Key, kvp.Value); - if (!typeWritten && kvp.Key.Equals(JwtHeaderParameterNames.Typ, StringComparison.Ordinal)) - typeWritten = true; - } - } - - if (!typeWritten) - writer.WriteString(JwtHeaderUtf8Bytes.Typ, string.IsNullOrEmpty(tokenType) ? JwtConstants.HeaderType : tokenType); - - writer.WriteEndObject(); - writer.Flush(); - - return memoryStream.ToArray(); - } - finally - { - writer?.Dispose(); - } - } - } - - internal static byte[] WriteJweHeader( - EncryptingCredentials encryptingCredentials, - string compressionAlgorithm, - string tokenType, - IDictionary jweHeaderClaims) - { - using (MemoryStream memoryStream = new MemoryStream()) - { - Utf8JsonWriter writer = null; - try - { - writer = new Utf8JsonWriter(memoryStream, new JsonWriterOptions { Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping }); - writer.WriteStartObject(); - - writer.WriteString(JwtHeaderUtf8Bytes.Alg, encryptingCredentials.Alg); - writer.WriteString(JwtHeaderUtf8Bytes.Enc, encryptingCredentials.Enc); - - if (encryptingCredentials.Key.KeyId != null) - writer.WriteString(JwtHeaderUtf8Bytes.Kid, encryptingCredentials.Key.KeyId); - - if (!string.IsNullOrEmpty(compressionAlgorithm)) - writer.WriteString(JwtHeaderUtf8Bytes.Zip, compressionAlgorithm); - - bool typeWritten = false; - bool ctyWritten = !encryptingCredentials.SetDefaultCtyClaim; - - // Current 6x Priority is jweHeaderClaims, type, cty - if (jweHeaderClaims != null && jweHeaderClaims.Count > 0) - { - foreach (KeyValuePair kvp in jweHeaderClaims) - { - JsonPrimitives.WriteObject(ref writer, kvp.Key, kvp.Value); - if (!typeWritten && kvp.Key.Equals(JwtHeaderParameterNames.Typ, StringComparison.Ordinal)) - typeWritten = true; - else if (!ctyWritten && kvp.Key.Equals(JwtHeaderParameterNames.Cty, StringComparison.Ordinal)) - ctyWritten = true; - } - } - - if (!typeWritten) - writer.WriteString(JwtHeaderUtf8Bytes.Typ, string.IsNullOrEmpty(tokenType) ? JwtConstants.HeaderType : tokenType); - - if (!ctyWritten) - writer.WriteString(JwtHeaderUtf8Bytes.Cty, JwtConstants.HeaderType); - - writer.WriteEndObject(); - writer.Flush(); - - return memoryStream.ToArray(); - } - finally - { - writer?.Dispose(); - } - } - } - - internal byte[] WritePayload(SecurityTokenDescriptor tokenDescriptor) - { - bool audienceSet = !string.IsNullOrEmpty(tokenDescriptor.Audience); - bool issuerSet = !string.IsNullOrEmpty(tokenDescriptor.Issuer); - - IDictionary payload = TokenUtilities.CreateDictionaryFromClaims(tokenDescriptor.Subject?.Claims, tokenDescriptor, audienceSet, issuerSet); - - // Duplicates are resolved according to the following priority: - // SecurityTokenDescriptor.{Audience, Issuer, Expires, IssuedAt, NotBefore}, SecurityTokenDescriptor.Claims, SecurityTokenDescriptor.Subject.Claims - - if (tokenDescriptor.Claims != null && tokenDescriptor.Claims.Count > 0) - { - foreach (var kvp in tokenDescriptor.Claims) - { - if (audienceSet && kvp.Key.Equals("aud", StringComparison.Ordinal)) - continue; - - if (issuerSet && kvp.Key.Equals("iss", StringComparison.Ordinal)) - continue; - - if (tokenDescriptor.Expires.HasValue && kvp.Key.Equals("exp", StringComparison.Ordinal)) - continue; - - if (tokenDescriptor.IssuedAt.HasValue && kvp.Key.Equals("iat", StringComparison.Ordinal)) - continue; - - if (tokenDescriptor.NotBefore.HasValue && kvp.Key.Equals("nbf", StringComparison.Ordinal)) - continue; - - payload[kvp.Key] = kvp.Value; - } - } - - bool expiresSet = false; - bool nbfSet = false; - bool iatSet = false; - - if (audienceSet) - { - if (LogHelper.IsEnabled(EventLogLevel.Informational) && payload.ContainsKey(JwtRegisteredClaimNames.Aud)) - LogHelper.LogInformation(LogHelper.FormatInvariant(LogMessages.IDX14113, LogHelper.MarkAsNonPII(nameof(tokenDescriptor.Audience)))); - - payload[JwtRegisteredClaimNames.Aud] = tokenDescriptor.Audience; - } - - if (issuerSet) - { - if (LogHelper.IsEnabled(EventLogLevel.Informational) && payload.ContainsKey(JwtRegisteredClaimNames.Iss)) - LogHelper.LogInformation(LogHelper.FormatInvariant(LogMessages.IDX14113, LogHelper.MarkAsNonPII(nameof(tokenDescriptor.Issuer)))); - - payload[JwtRegisteredClaimNames.Iss] = tokenDescriptor.Issuer; - } - - if (tokenDescriptor.Expires.HasValue) - { - if (LogHelper.IsEnabled(EventLogLevel.Informational) && payload.ContainsKey(JwtRegisteredClaimNames.Exp)) - LogHelper.LogInformation(LogHelper.FormatInvariant(LogMessages.IDX14113, LogHelper.MarkAsNonPII(nameof(tokenDescriptor.Expires)))); - - payload[JwtRegisteredClaimNames.Exp] = EpochTime.GetIntDate(tokenDescriptor.Expires.Value); - expiresSet = true; - } - - if (tokenDescriptor.IssuedAt.HasValue) - { - if (LogHelper.IsEnabled(EventLogLevel.Informational) && payload.ContainsKey(JwtRegisteredClaimNames.Iat)) - LogHelper.LogInformation(LogHelper.FormatInvariant(LogMessages.IDX14113, LogHelper.MarkAsNonPII(nameof(tokenDescriptor.IssuedAt)))); - - payload[JwtRegisteredClaimNames.Iat] = EpochTime.GetIntDate(tokenDescriptor.IssuedAt.Value); - iatSet = true; - } - - if (tokenDescriptor.NotBefore.HasValue) - { - if (LogHelper.IsEnabled(EventLogLevel.Informational) && payload.ContainsKey(JwtRegisteredClaimNames.Nbf)) - LogHelper.LogInformation(LogHelper.FormatInvariant(LogMessages.IDX14113, LogHelper.MarkAsNonPII(nameof(tokenDescriptor.NotBefore)))); - - payload[JwtRegisteredClaimNames.Nbf] = EpochTime.GetIntDate(tokenDescriptor.NotBefore.Value); - nbfSet = true; - } - - // by default we set these three properties only if they haven't been set. - if (SetDefaultTimesOnTokenCreation) - { - long now = EpochTime.GetIntDate(DateTime.UtcNow); - - if (!expiresSet && !payload.ContainsKey(JwtRegisteredClaimNames.Exp)) - payload.Add(JwtRegisteredClaimNames.Exp, now + TokenLifetimeInMinutes * 60); - - if (!iatSet && !payload.ContainsKey(JwtRegisteredClaimNames.Iat)) - payload.Add(JwtRegisteredClaimNames.Iat, now); - - if (!nbfSet && !payload.ContainsKey(JwtRegisteredClaimNames.Nbf)) - payload.Add(JwtRegisteredClaimNames.Nbf, now); - } - - using (MemoryStream memoryStream = new()) - { - Utf8JsonWriter writer = null; - try - { - writer = new Utf8JsonWriter(memoryStream, new JsonWriterOptions { Encoder = JavaScriptEncoder.UnsafeRelaxedJsonEscaping }); - writer.WriteStartObject(); - - foreach (KeyValuePair kvp in payload) - { - if (kvp.Value is IList l) - { - writer.WriteStartArray(kvp.Key); - - foreach (object obj in l) - JsonPrimitives.WriteObjectValue(ref writer, obj); - - writer.WriteEndArray(); - } - else - { - JsonPrimitives.WriteObject(ref writer, kvp.Key, kvp.Value); - } - } - - writer.WriteEndObject(); - writer.Flush(); - - return memoryStream.ToArray(); - } - finally - { - writer?.Dispose(); - } - } - } - - /// - /// Creates a JWE (Json Web Encryption). - /// - /// A string containing JSON which represents the JWT token payload. - /// Defines the security key and algorithm that will be used to encrypt the JWT. - /// A JWE in compact serialization format. - public virtual string CreateToken(string payload, EncryptingCredentials encryptingCredentials) - { - if (string.IsNullOrEmpty(payload)) - throw LogHelper.LogArgumentNullException(nameof(payload)); - - if (encryptingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); - - return CreateToken(Encoding.UTF8.GetBytes(payload), null, encryptingCredentials, null, null, null, null); - } - - /// - /// Creates a JWE (Json Web Encryption). - /// - /// A string containing JSON which represents the JWT token payload. - /// Defines the security key and algorithm that will be used to encrypt the JWT. - /// Defines the dictionary containing any custom header claims that need to be added to the outer JWT token header. - /// if is null. - /// if is null. - /// if is null. - /// if , - /// , , and/or - /// are present inside of . - /// A JWS in Compact Serialization Format. - public virtual string CreateToken(string payload, EncryptingCredentials encryptingCredentials, IDictionary additionalHeaderClaims) - { - if (string.IsNullOrEmpty(payload)) - throw LogHelper.LogArgumentNullException(nameof(payload)); - - if (encryptingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); - - if (additionalHeaderClaims == null) - throw LogHelper.LogArgumentNullException(nameof(additionalHeaderClaims)); - - return CreateToken(Encoding.UTF8.GetBytes(payload), null, encryptingCredentials, null, additionalHeaderClaims, null, null); - } - - /// - /// Creates a JWE (Json Web Encryption). - /// - /// A string containing JSON which represents the JWT token payload. - /// Defines the security key and algorithm that will be used to sign the JWT. - /// Defines the security key and algorithm that will be used to encrypt the JWT. - /// if is null. - /// if is null. - /// if is null. - /// A JWE in compact serialization format. - public virtual string CreateToken(string payload, SigningCredentials signingCredentials, EncryptingCredentials encryptingCredentials) - { - if (string.IsNullOrEmpty(payload)) - throw LogHelper.LogArgumentNullException(nameof(payload)); - - if (signingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(signingCredentials)); - - if (encryptingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); - - return CreateToken(Encoding.UTF8.GetBytes(payload), signingCredentials, encryptingCredentials, null, null, null, null); - } - - /// - /// Creates a JWE (Json Web Encryption). - /// - /// A string containing JSON which represents the JWT token payload. - /// Defines the security key and algorithm that will be used to sign the JWT. - /// Defines the security key and algorithm that will be used to encrypt the JWT. - /// Defines the dictionary containing any custom header claims that need to be added to the outer JWT token header. - /// if is null. - /// if is null. - /// if is null. - /// if is null. - /// if , - /// , , and/or - /// are present inside of . - /// A JWE in compact serialization format. - public virtual string CreateToken( - string payload, - SigningCredentials signingCredentials, - EncryptingCredentials encryptingCredentials, - IDictionary additionalHeaderClaims) - { - if (string.IsNullOrEmpty(payload)) - throw LogHelper.LogArgumentNullException(nameof(payload)); - - if (signingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(signingCredentials)); - - if (encryptingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); - - if (additionalHeaderClaims == null) - throw LogHelper.LogArgumentNullException(nameof(additionalHeaderClaims)); - - return CreateToken(Encoding.UTF8.GetBytes(payload), signingCredentials, encryptingCredentials, null, additionalHeaderClaims, null, null); - } - - /// - /// Creates a JWE (Json Web Encryption). - /// - /// A string containing JSON which represents the JWT token payload. - /// Defines the security key and algorithm that will be used to encrypt the JWT. - /// Defines the compression algorithm that will be used to compress the JWT token payload. - /// A JWE in compact serialization format. - public virtual string CreateToken(string payload, EncryptingCredentials encryptingCredentials, string compressionAlgorithm) - { - if (string.IsNullOrEmpty(payload)) - throw LogHelper.LogArgumentNullException(nameof(payload)); - - if (encryptingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); - - if (string.IsNullOrEmpty(compressionAlgorithm)) - throw LogHelper.LogArgumentNullException(nameof(compressionAlgorithm)); - - return CreateToken(Encoding.UTF8.GetBytes(payload), null, encryptingCredentials, compressionAlgorithm, null, null, null); - } - - /// - /// Creates a JWE (Json Web Encryption). - /// - /// A string containing JSON which represents the JWT token payload. - /// Defines the security key and algorithm that will be used to sign the JWT. - /// Defines the security key and algorithm that will be used to encrypt the JWT. - /// Defines the compression algorithm that will be used to compress the JWT token payload. - /// if is null. - /// if is null. - /// if is null. - /// if is null. - /// A JWE in compact serialization format. - public virtual string CreateToken(string payload, SigningCredentials signingCredentials, EncryptingCredentials encryptingCredentials, string compressionAlgorithm) - { - if (string.IsNullOrEmpty(payload)) - throw LogHelper.LogArgumentNullException(nameof(payload)); - - if (signingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(signingCredentials)); - - if (encryptingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); - - if (string.IsNullOrEmpty(compressionAlgorithm)) - throw LogHelper.LogArgumentNullException(nameof(compressionAlgorithm)); - - return CreateToken(Encoding.UTF8.GetBytes(payload), signingCredentials, encryptingCredentials, compressionAlgorithm, null, null, null); - } - - /// - /// Creates a JWE (Json Web Encryption). - /// - /// A string containing JSON which represents the JWT token payload. - /// Defines the security key and algorithm that will be used to sign the JWT. - /// Defines the security key and algorithm that will be used to encrypt the JWT. - /// Defines the compression algorithm that will be used to compress the JWT token payload. - /// Defines the dictionary containing any custom header claims that need to be added to the outer JWT token header. - /// Defines the dictionary containing any custom header claims that need to be added to the inner JWT token header. - /// if is null. - /// if is null. - /// if is null. - /// if is null. - /// if is null. - /// if , - /// , , and/or - /// are present inside of . - /// A JWE in compact serialization format. - public virtual string CreateToken( - string payload, - SigningCredentials signingCredentials, - EncryptingCredentials encryptingCredentials, - string compressionAlgorithm, - IDictionary additionalHeaderClaims, - IDictionary additionalInnerHeaderClaims) - { - if (string.IsNullOrEmpty(payload)) - throw LogHelper.LogArgumentNullException(nameof(payload)); - - if (signingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(signingCredentials)); - - if (encryptingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); - - if (string.IsNullOrEmpty(compressionAlgorithm)) - throw LogHelper.LogArgumentNullException(nameof(compressionAlgorithm)); - - if (additionalHeaderClaims == null) - throw LogHelper.LogArgumentNullException(nameof(additionalHeaderClaims)); - - if (additionalInnerHeaderClaims == null) - throw LogHelper.LogArgumentNullException(nameof(additionalInnerHeaderClaims)); - - return CreateToken( - Encoding.UTF8.GetBytes(payload), - signingCredentials, - encryptingCredentials, - compressionAlgorithm, - additionalHeaderClaims, - additionalInnerHeaderClaims, - null); - } - - /// - /// Creates a JWE (Json Web Encryption). - /// - /// A string containing JSON which represents the JWT token payload. - /// Defines the security key and algorithm that will be used to sign the JWT. - /// Defines the security key and algorithm that will be used to encrypt the JWT. - /// Defines the compression algorithm that will be used to compress the JWT token payload. - /// Defines the dictionary containing any custom header claims that need to be added to the outer JWT token header. - /// if is null. - /// if is null. - /// if is null. - /// if is null. - /// if is null. - /// if , - /// , , and/or - /// are present inside of . - /// A JWE in compact serialization format. - public virtual string CreateToken( - string payload, - SigningCredentials signingCredentials, - EncryptingCredentials encryptingCredentials, - string compressionAlgorithm, - IDictionary additionalHeaderClaims) - { - if (string.IsNullOrEmpty(payload)) - throw LogHelper.LogArgumentNullException(nameof(payload)); - - if (signingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(signingCredentials)); - - if (encryptingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); - - if (string.IsNullOrEmpty(compressionAlgorithm)) - throw LogHelper.LogArgumentNullException(nameof(compressionAlgorithm)); - - if (additionalHeaderClaims == null) - throw LogHelper.LogArgumentNullException(nameof(additionalHeaderClaims)); - - return CreateToken(Encoding.UTF8.GetBytes(payload), signingCredentials, encryptingCredentials, compressionAlgorithm, additionalHeaderClaims, null, null); - } - - internal static string CreateToken - ( - byte[] payloadBytes, - SigningCredentials signingCredentials, - EncryptingCredentials encryptingCredentials, - string compressionAlgorithm, - IDictionary additionalHeaderClaims, - IDictionary additionalInnerHeaderClaims, - string tokenType) - { - // TODO - Create one Span to write everything into and avoid moving between string -> Utf8Bytes -> string + string -> Utf8bytes. - // The Header, Payload, Message, etc. - // Avoid creating and writing to MemoryStreams and then calling ToArray(); - // Rent ArrayPools - // Would like to use, but the following is internal. - // Will see how hard it is to use the code. - // ArrayBufferWriter buffer = new System.Buffers.ArrayBufferWriter(); - // C:\github\dotnet\runtime\src\libraries\Common\src\System\Buffers\ArrayBufferWriter.cs - // If we can't use ArrayBufferWriter, we can make use our own pool in tokens. - // A possibility is to use our internal pool: sealed class DisposableObjectPool where T : class, IDisposable - // When creating an Encrypted Token, pass in a Span, representing the inner token. - - if (additionalHeaderClaims?.Count > 0 && additionalHeaderClaims.Keys.Intersect(JwtTokenUtilities.DefaultHeaderParameters, StringComparer.OrdinalIgnoreCase).Any()) - throw LogHelper.LogExceptionMessage(new SecurityTokenException(LogHelper.FormatInvariant(LogMessages.IDX14116, LogHelper.MarkAsNonPII(nameof(additionalHeaderClaims)), LogHelper.MarkAsNonPII(string.Join(", ", JwtTokenUtilities.DefaultHeaderParameters))))); - - if (additionalInnerHeaderClaims?.Count > 0 && additionalInnerHeaderClaims.Keys.Intersect(JwtTokenUtilities.DefaultHeaderParameters, StringComparer.OrdinalIgnoreCase).Any()) - throw LogHelper.LogExceptionMessage(new SecurityTokenException(LogHelper.FormatInvariant(LogMessages.IDX14116, nameof(additionalInnerHeaderClaims), string.Join(", ", JwtTokenUtilities.DefaultHeaderParameters)))); - - byte[] headerBytes = WriteJwsHeader(signingCredentials, tokenType, additionalInnerHeaderClaims, encryptingCredentials == null ? additionalHeaderClaims : null); - - string header = Base64UrlEncoder.Encode(headerBytes); - string message = header + "." + Base64UrlEncoder.Encode(payloadBytes); - var rawSignature = signingCredentials == null ? string.Empty : JwtTokenUtilities.CreateEncodedSignature(message, signingCredentials); - - if (encryptingCredentials != null) - return EncryptToken(Encoding.UTF8.GetBytes(message + "." + rawSignature), encryptingCredentials, compressionAlgorithm, additionalHeaderClaims, tokenType); - - return message + "." + rawSignature; - } - - internal static byte[] CompressToken(byte[] utf8Bytes, string compressionAlgorithm) - { - if (string.IsNullOrEmpty(compressionAlgorithm)) - throw LogHelper.LogArgumentNullException(nameof(compressionAlgorithm)); - - if (!CompressionProviderFactory.Default.IsSupportedAlgorithm(compressionAlgorithm)) - throw LogHelper.LogExceptionMessage(new NotSupportedException(LogHelper.FormatInvariant(TokenLogMessages.IDX10682, LogHelper.MarkAsNonPII(compressionAlgorithm)))); - - var compressionProvider = CompressionProviderFactory.Default.CreateCompressionProvider(compressionAlgorithm); - - return compressionProvider.Compress(utf8Bytes) ?? throw LogHelper.LogExceptionMessage(new InvalidOperationException(LogHelper.FormatInvariant(TokenLogMessages.IDX10680, LogHelper.MarkAsNonPII(compressionAlgorithm)))); - } - private static StringComparison GetStringComparisonRuleIf509(SecurityKey securityKey) => (securityKey is X509SecurityKey) ? StringComparison.OrdinalIgnoreCase : StringComparison.Ordinal; @@ -1041,192 +369,6 @@ private string DecryptToken(JsonWebToken jwtToken, TokenValidationParameters val }); } - /// - /// Encrypts a JWS. - /// - /// A 'JSON Web Token' (JWT) in JWS Compact Serialization Format. - /// Defines the security key and algorithm that will be used to encrypt the . - /// if is null or empty. - /// if is null. - /// if both and . are null. - /// if the CryptoProviderFactory being used does not support the (algorithm), pair. - /// if unable to create a token encryption provider for the (algorithm), pair. - /// if encryption fails using the (algorithm), pair. - /// if not using one of the supported content encryption key (CEK) algorithms: 128, 384 or 512 AesCbcHmac (this applies in the case of key wrap only, not direct encryption). - public string EncryptToken(string innerJwt, EncryptingCredentials encryptingCredentials) - { - if (string.IsNullOrEmpty(innerJwt)) - throw LogHelper.LogArgumentNullException(nameof(innerJwt)); - - if (encryptingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); - - return EncryptTokenPrivate(innerJwt, encryptingCredentials, null, null, null); - } - - /// - /// Encrypts a JWS. - /// - /// A 'JSON Web Token' (JWT) in JWS Compact Serialization Format. - /// Defines the security key and algorithm that will be used to encrypt the . - /// Defines the dictionary containing any custom header claims that need to be added to the outer JWT token header. - /// if is null or empty. - /// if is null. - /// if is null. - /// if both and . are null. - /// if the CryptoProviderFactory being used does not support the (algorithm), pair. - /// if unable to create a token encryption provider for the (algorithm), pair. - /// if encryption fails using the (algorithm), pair. - /// if not using one of the supported content encryption key (CEK) algorithms: 128, 384 or 512 AesCbcHmac (this applies in the case of key wrap only, not direct encryption). - public string EncryptToken(string innerJwt, EncryptingCredentials encryptingCredentials, IDictionary additionalHeaderClaims) - { - if (string.IsNullOrEmpty(innerJwt)) - throw LogHelper.LogArgumentNullException(nameof(innerJwt)); - - if (encryptingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); - - if (additionalHeaderClaims == null) - throw LogHelper.LogArgumentNullException(nameof(additionalHeaderClaims)); - - return EncryptTokenPrivate(innerJwt, encryptingCredentials, null, additionalHeaderClaims, null); - } - - /// - /// Encrypts a JWS. - /// - /// A 'JSON Web Token' (JWT) in JWS Compact Serialization Format. - /// Defines the security key and algorithm that will be used to encrypt the . - /// Defines the compression algorithm that will be used to compress the 'innerJwt'. - /// if is null or empty. - /// if is null. - /// if is null or empty. - /// if both and . are null. - /// if the CryptoProviderFactory being used does not support the (algorithm), pair. - /// if unable to create a token encryption provider for the (algorithm), pair. - /// if compression using fails. - /// if encryption fails using the (algorithm), pair. - /// if not using one of the supported content encryption key (CEK) algorithms: 128, 384 or 512 AesCbcHmac (this applies in the case of key wrap only, not direct encryption). - public string EncryptToken(string innerJwt, EncryptingCredentials encryptingCredentials, string algorithm) - { - if (string.IsNullOrEmpty(innerJwt)) - throw LogHelper.LogArgumentNullException(nameof(innerJwt)); - - if (encryptingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); - - if (string.IsNullOrEmpty(algorithm)) - throw LogHelper.LogArgumentNullException(nameof(algorithm)); - - return EncryptTokenPrivate(innerJwt, encryptingCredentials, algorithm, null, null); - } - - /// - /// Encrypts a JWS. - /// - /// A 'JSON Web Token' (JWT) in JWS Compact Serialization Format. - /// Defines the security key and algorithm that will be used to encrypt the . - /// Defines the compression algorithm that will be used to compress the - /// Defines the dictionary containing any custom header claims that need to be added to the outer JWT token header. - /// if is null or empty. - /// if is null. - /// if is null or empty. - /// if is null or empty. - /// if both and . are null. - /// if the CryptoProviderFactory being used does not support the (algorithm), pair. - /// if unable to create a token encryption provider for the (algorithm), pair. - /// if compression using 'algorithm' fails. - /// if encryption fails using the (algorithm), pair. - /// if not using one of the supported content encryption key (CEK) algorithms: 128, 384 or 512 AesCbcHmac (this applies in the case of key wrap only, not direct encryption). - public string EncryptToken(string innerJwt, - EncryptingCredentials encryptingCredentials, - string algorithm, - IDictionary additionalHeaderClaims) - { - if (string.IsNullOrEmpty(innerJwt)) - throw LogHelper.LogArgumentNullException(nameof(innerJwt)); - - if (encryptingCredentials == null) - throw LogHelper.LogArgumentNullException(nameof(encryptingCredentials)); - - if (string.IsNullOrEmpty(algorithm)) - throw LogHelper.LogArgumentNullException(nameof(algorithm)); - - if (additionalHeaderClaims == null) - throw LogHelper.LogArgumentNullException(nameof(additionalHeaderClaims)); - - return EncryptTokenPrivate(innerJwt, encryptingCredentials, algorithm, additionalHeaderClaims, null); - } - - private static string EncryptTokenPrivate( - string innerJwt, - EncryptingCredentials encryptingCredentials, - string compressionAlgorithm, - IDictionary additionalHeaderClaims, - string tokenType) - { - return (EncryptToken( - Encoding.UTF8.GetBytes(innerJwt), - encryptingCredentials, - compressionAlgorithm, - additionalHeaderClaims, - tokenType)); - } - - internal static string EncryptToken( - byte[] innerTokenUtf8Bytes, - EncryptingCredentials encryptingCredentials, - string compressionAlgorithm, - IDictionary additionalHeaderClaims, - string tokenType) - { - CryptoProviderFactory cryptoProviderFactory = encryptingCredentials.CryptoProviderFactory ?? encryptingCredentials.Key.CryptoProviderFactory; - - if (cryptoProviderFactory == null) - throw LogHelper.LogExceptionMessage(new ArgumentException(TokenLogMessages.IDX10620)); - - SecurityKey securityKey = JwtTokenUtilities.GetSecurityKey(encryptingCredentials, cryptoProviderFactory, additionalHeaderClaims, out byte[] wrappedKey); - - using (AuthenticatedEncryptionProvider encryptionProvider = cryptoProviderFactory.CreateAuthenticatedEncryptionProvider(securityKey, encryptingCredentials.Enc)) - { - if (encryptionProvider == null) - throw LogHelper.LogExceptionMessage(new SecurityTokenEncryptionFailedException(LogMessages.IDX14103)); - - byte[] jweHeader = WriteJweHeader(encryptingCredentials, compressionAlgorithm, tokenType, additionalHeaderClaims); - byte[] plainText; - if (!string.IsNullOrEmpty(compressionAlgorithm)) - { - try - { - plainText = CompressToken(innerTokenUtf8Bytes, compressionAlgorithm); - } - catch (Exception ex) - { - throw LogHelper.LogExceptionMessage(new SecurityTokenCompressionFailedException(LogHelper.FormatInvariant(TokenLogMessages.IDX10680, LogHelper.MarkAsNonPII(compressionAlgorithm)), ex)); - } - } - else - { - plainText = innerTokenUtf8Bytes; - } - - try - { - string rawHeader = Base64UrlEncoder.Encode(jweHeader); - - //TODO - why isn't the result checked. - var encryptionResult = encryptionProvider.Encrypt(plainText, Encoding.ASCII.GetBytes(rawHeader)); - return JwtConstants.DirectKeyUseAlg.Equals(encryptingCredentials.Alg) ? - string.Join(".", rawHeader, string.Empty, Base64UrlEncoder.Encode(encryptionResult.IV), Base64UrlEncoder.Encode(encryptionResult.Ciphertext), Base64UrlEncoder.Encode(encryptionResult.AuthenticationTag)) : - string.Join(".", rawHeader, Base64UrlEncoder.Encode(wrappedKey), Base64UrlEncoder.Encode(encryptionResult.IV), Base64UrlEncoder.Encode(encryptionResult.Ciphertext), Base64UrlEncoder.Encode(encryptionResult.AuthenticationTag)); - } - catch (Exception ex) - { - throw LogHelper.LogExceptionMessage(new SecurityTokenEncryptionFailedException(LogHelper.FormatInvariant(TokenLogMessages.IDX10616, LogHelper.MarkAsNonPII(encryptingCredentials.Enc), encryptingCredentials.Key), ex)); - } - } - } - private static SecurityKey ResolveTokenDecryptionKeyFromConfig(JsonWebToken jwtToken, BaseConfiguration configuration) { if (jwtToken == null) @@ -1257,102 +399,6 @@ private static SecurityKey ResolveTokenDecryptionKeyFromConfig(JsonWebToken jwtT return null; } - internal IEnumerable GetContentEncryptionKeys(JsonWebToken jwtToken, TokenValidationParameters validationParameters, BaseConfiguration configuration) - { - IEnumerable keys = null; - - // First we check to see if the caller has set a custom decryption resolver on TVP for the call, if so any keys set on TVP and keys in Configuration are ignored. - // If no custom decryption resolver set, we'll check to see if they've set some static decryption keys on TVP. If a key found, we ignore configuration. - // If no key found in TVP, we'll check the configuration. - if (validationParameters.TokenDecryptionKeyResolver != null) - { - keys = validationParameters.TokenDecryptionKeyResolver(jwtToken.EncodedToken, jwtToken, jwtToken.Kid, validationParameters); - } - else - { - var key = ResolveTokenDecryptionKey(jwtToken.EncodedToken, jwtToken, validationParameters); - if (key != null) - { - if (LogHelper.IsEnabled(EventLogLevel.Informational)) - LogHelper.LogInformation(TokenLogMessages.IDX10904, key); - } - else if (configuration != null) - { - key = ResolveTokenDecryptionKeyFromConfig(jwtToken, configuration); - if (key != null && LogHelper.IsEnabled(EventLogLevel.Informational)) - LogHelper.LogInformation(TokenLogMessages.IDX10905, key); - } - - if (key != null) - keys = new List { key }; - } - - // on decryption for ECDH-ES, we get the public key from the EPK value see: https://datatracker.ietf.org/doc/html/rfc7518#appendix-C - // we need the ECDSASecurityKey for the receiver, use TokenValidationParameters.TokenDecryptionKey - - // control gets here if: - // 1. User specified delegate: TokenDecryptionKeyResolver returned null - // 2. ResolveTokenDecryptionKey returned null - // 3. ResolveTokenDecryptionKeyFromConfig returned null - // Try all the keys. This is the degenerate case, not concerned about perf. - if (keys == null) - { - keys = JwtTokenUtilities.GetAllDecryptionKeys(validationParameters); - if (configuration != null) - keys = keys == null ? configuration.TokenDecryptionKeys : keys.Concat(configuration.TokenDecryptionKeys); - } - - if (jwtToken.Alg.Equals(JwtConstants.DirectKeyUseAlg, StringComparison.Ordinal) - || jwtToken.Alg.Equals(SecurityAlgorithms.EcdhEs, StringComparison.Ordinal)) - return keys; - - var unwrappedKeys = new List(); - // keep track of exceptions thrown, keys that were tried - StringBuilder exceptionStrings = null; - StringBuilder keysAttempted = null; - foreach (var key in keys) - { - try - { -#if NET472 || NET6_0_OR_GREATER - if (SupportedAlgorithms.EcdsaWrapAlgorithms.Contains(jwtToken.Alg)) - { - // on decryption we get the public key from the EPK value see: https://datatracker.ietf.org/doc/html/rfc7518#appendix-C - var ecdhKeyExchangeProvider = new EcdhKeyExchangeProvider( - key as ECDsaSecurityKey, - validationParameters.TokenDecryptionKey as ECDsaSecurityKey, - jwtToken.Alg, - jwtToken.Enc); - jwtToken.TryGetHeaderValue(JwtHeaderParameterNames.Apu, out string apu); - jwtToken.TryGetHeaderValue(JwtHeaderParameterNames.Apv, out string apv); - SecurityKey kdf = ecdhKeyExchangeProvider.GenerateKdf(apu, apv); - var kwp = key.CryptoProviderFactory.CreateKeyWrapProviderForUnwrap(kdf, ecdhKeyExchangeProvider.GetEncryptionAlgorithm()); - var unwrappedKey = kwp.UnwrapKey(Base64UrlEncoder.DecodeBytes(jwtToken.EncryptedKey)); - unwrappedKeys.Add(new SymmetricSecurityKey(unwrappedKey)); - } - else -#endif - if (key.CryptoProviderFactory.IsSupportedAlgorithm(jwtToken.Alg, key)) - { - var kwp = key.CryptoProviderFactory.CreateKeyWrapProviderForUnwrap(key, jwtToken.Alg); - var unwrappedKey = kwp.UnwrapKey(jwtToken.EncryptedKeyBytes); - unwrappedKeys.Add(new SymmetricSecurityKey(unwrappedKey)); - } - } - catch (Exception ex) - { - (exceptionStrings ??= new StringBuilder()).AppendLine(ex.ToString()); - } - - (keysAttempted ??= new StringBuilder()).AppendLine(key.ToString()); - } - - if (unwrappedKeys.Count > 0 && exceptionStrings is null) - return unwrappedKeys; - else - throw LogHelper.LogExceptionMessage(new SecurityTokenKeyWrapException(LogHelper.FormatInvariant(TokenLogMessages.IDX10618, (object)keysAttempted ?? "", (object)exceptionStrings ?? "", jwtToken))); - } - /// /// Returns a to use when decrypting a JWE. /// diff --git a/src/Microsoft.IdentityModel.JsonWebTokens/JwtTokenUtilities.cs b/src/Microsoft.IdentityModel.JsonWebTokens/JwtTokenUtilities.cs index 34534e3825..2b19867a8e 100644 --- a/src/Microsoft.IdentityModel.JsonWebTokens/JwtTokenUtilities.cs +++ b/src/Microsoft.IdentityModel.JsonWebTokens/JwtTokenUtilities.cs @@ -8,12 +8,10 @@ using System.Security.Claims; using System.Security.Cryptography; using System.Text; -using System.Text.Json; using System.Text.RegularExpressions; using Microsoft.IdentityModel.Abstractions; using Microsoft.IdentityModel.Logging; using Microsoft.IdentityModel.Tokens; -using Microsoft.IdentityModel.Tokens.Json; using TokenLogMessages = Microsoft.IdentityModel.Tokens.LogMessages; @@ -47,7 +45,7 @@ public partial class JwtTokenUtilities private static Regex CreateJweRegex() => new Regex(JwtConstants.JweCompactSerializationRegex, RegexOptions.Compiled | RegexOptions.CultureInvariant, TimeSpan.FromMilliseconds(_regexMatchTimeoutMilliseconds)); #endif - internal static IList DefaultHeaderParameters = new List() + internal static List DefaultHeaderParameters = new List() { JwtHeaderParameterNames.Alg, JwtHeaderParameterNames.Kid, @@ -121,6 +119,82 @@ public static string CreateEncodedSignature(string input, SigningCredentials sig } } + internal static byte[] CreateEncodedSignature( + byte[] input, + int offset, + int count, + SigningCredentials signingCredentials) + { + if (input == null) + throw LogHelper.LogArgumentNullException(nameof(input)); + + if (signingCredentials == null) + return null; + + var cryptoProviderFactory = signingCredentials.CryptoProviderFactory ?? signingCredentials.Key.CryptoProviderFactory; + var signatureProvider = cryptoProviderFactory.CreateForSigning(signingCredentials.Key, signingCredentials.Algorithm) ?? + throw LogHelper.LogExceptionMessage( + new InvalidOperationException( + LogHelper.FormatInvariant( + TokenLogMessages.IDX10637, + signingCredentials.Key == null ? "Null" : signingCredentials.Key.ToString(), + LogHelper.MarkAsNonPII(signingCredentials.Algorithm)))); + + try + { + if (LogHelper.IsEnabled(EventLogLevel.Verbose)) + LogHelper.LogVerbose(LogMessages.IDX14200); + + return signatureProvider.Sign(input, offset, count); + } + finally + { + cryptoProviderFactory.ReleaseSignatureProvider(signatureProvider); + } + } + +#if NET6_0_OR_GREATER + /// + /// Produces a signature over the . + /// + /// Span containing bytes to be signed. + /// destination for signature. + /// The that contain crypto specs used to sign the token. + /// + /// The size of the signature. + /// 'input' or 'signingCredentials' is null. + internal static bool CreateSignature( + ReadOnlySpan data, + Span destination, + SigningCredentials signingCredentials, + out int bytesWritten) + { + bytesWritten = 0; + if (signingCredentials == null) + return false; + + var cryptoProviderFactory = signingCredentials.CryptoProviderFactory ?? signingCredentials.Key.CryptoProviderFactory; + var signatureProvider = cryptoProviderFactory.CreateForSigning(signingCredentials.Key, signingCredentials.Algorithm) ?? + throw LogHelper.LogExceptionMessage( + new InvalidOperationException( + LogHelper.FormatInvariant( + TokenLogMessages.IDX10637, signingCredentials.Key == null ? "Null" : signingCredentials.Key.ToString(), + LogHelper.MarkAsNonPII(signingCredentials.Algorithm)))); + + try + { + if (LogHelper.IsEnabled(EventLogLevel.Verbose)) + LogHelper.LogVerbose(LogMessages.IDX14200); + + return signatureProvider.Sign(data, destination, out bytesWritten); + } + finally + { + cryptoProviderFactory.ReleaseSignatureProvider(signatureProvider); + } + } +#endif + /// /// Decompress JWT token bytes. /// @@ -419,7 +493,7 @@ internal static string SafeLogJwtToken(object obj) // not a string, we do not know how to sanitize so we return a String which represents the object instance if (!(obj is string token)) return obj.GetType().ToString(); - + int lastDot = token.LastIndexOf("."); // no dots, not a JWT, we do not know how to sanitize so we return UnrecognizedEncodedToken diff --git a/src/Microsoft.IdentityModel.Tokens/AsymmetricAdapter.cs b/src/Microsoft.IdentityModel.Tokens/AsymmetricAdapter.cs index 1fb85e505f..9ad8fd0f7a 100644 --- a/src/Microsoft.IdentityModel.Tokens/AsymmetricAdapter.cs +++ b/src/Microsoft.IdentityModel.Tokens/AsymmetricAdapter.cs @@ -10,8 +10,12 @@ namespace Microsoft.IdentityModel.Tokens delegate byte[] EncryptDelegate(byte[] bytes); delegate byte[] DecryptDelegate(byte[] bytes); delegate byte[] SignDelegate(byte[] bytes); + delegate byte[] SignUsingOffsetDelegate(byte[] bytes, int offset, int count); +#if NET6_0_OR_GREATER + delegate bool SignUsingSpanDelegate(ReadOnlySpan bytes, Span signature, out int bytesWritten); +#endif delegate bool VerifyDelegate(byte[] bytes, byte[] signature); - delegate bool VerifyDelegateWithLength(byte[] bytes, int start, int length, byte[] signature); + delegate bool VerifyUsingOffsetDelegate(byte[] bytes, int offset, int count, byte[] signature); /// /// This adapter abstracts the 'RSA' differences between versions of .Net targets. @@ -23,11 +27,15 @@ internal class AsymmetricAdapter : IDisposable #endif private bool _disposeCryptoOperators = false; private bool _disposed = false; - private DecryptDelegate DecryptFunction = DecryptFunctionNotFound; - private EncryptDelegate EncryptFunction = EncryptFunctionNotFound; - private SignDelegate SignatureFunction = SignatureFunctionNotFound; - private VerifyDelegate VerifyFunction = VerifyFunctionNotFound; - private VerifyDelegateWithLength VerifyFunctionWithLength = VerifyFunctionWithLengthNotFound; + private DecryptDelegate _decryptFunction = DecryptFunctionNotFound; + private EncryptDelegate _encryptFunction = EncryptFunctionNotFound; + private SignDelegate _signFunction = SignFunctionNotFound; + private SignUsingOffsetDelegate _signUsingOffsetFunction = SignUsingOffsetNotFound; +#if NET6_0_OR_GREATER + private SignUsingSpanDelegate _signUsingSpanFunction = SignUsingSpanNotFound; +#endif + private VerifyDelegate _verifyFunction = VerifyNotFound; + private VerifyUsingOffsetDelegate _verifyUsingOffsetFunction = VerifyUsingOffsetNotFound; // Encryption algorithms do not need a HashAlgorithm, this is called by RSAKeyWrap internal AsymmetricAdapter(SecurityKey key, string algorithm, bool requirePrivateKey) @@ -35,6 +43,13 @@ internal AsymmetricAdapter(SecurityKey key, string algorithm, bool requirePrivat { } + internal AsymmetricAdapter(SecurityKey key, string algorithm, HashAlgorithm hashAlgorithm, HashAlgorithmName hashAlgorithmName, bool requirePrivateKey) + : this(key, algorithm, hashAlgorithm, requirePrivateKey) + { + + HashAlgorithmName = hashAlgorithmName; + } + internal AsymmetricAdapter(SecurityKey key, string algorithm, HashAlgorithm hashAlgorithm, bool requirePrivateKey) { HashAlgorithm = hashAlgorithm; @@ -73,7 +88,7 @@ internal AsymmetricAdapter(SecurityKey key, string algorithm, HashAlgorithm hash internal byte[] Decrypt(byte[] data) { - return DecryptFunction(data); + return _decryptFunction(data); } internal static byte[] DecryptFunctionNotFound(byte[] _) @@ -117,7 +132,7 @@ protected virtual void Dispose(bool disposing) internal byte[] Encrypt(byte[] data) { - return EncryptFunction(data); + return _encryptFunction(data); } internal static byte[] EncryptFunctionNotFound(byte[] _) @@ -131,16 +146,20 @@ internal static byte[] EncryptFunctionNotFound(byte[] _) private void InitializeUsingEcdsaSecurityKey(ECDsaSecurityKey ecdsaSecurityKey) { ECDsa = ecdsaSecurityKey.ECDsa; - SignatureFunction = SignWithECDsa; - VerifyFunction = VerifyWithECDsa; - VerifyFunctionWithLength = VerifyWithECDsaWithLength; + _signFunction = SignECDsa; + _signUsingOffsetFunction = SignUsingOffsetECDsa; +#if NET6_0_OR_GREATER + _signUsingSpanFunction = SignUsingSpanECDsa; +#endif + _verifyFunction = VerifyECDsa; + _verifyUsingOffsetFunction = VerifyUsingOffsetECDsa; } private void InitializeUsingRsa(RSA rsa, string algorithm) { // The return value for X509Certificate2.GetPrivateKey OR X509Certificate2.GetPublicKey.Key is a RSACryptoServiceProvider // These calls return an AsymmetricAlgorithm which doesn't have API's to do much and need to be cast. - // RSACryptoServiceProvider is wrapped with RSACryptoServiceProviderProxy as some CryptoServideProviders (CSP's) do + // RSACryptoServiceProvider is wrapped with RSACryptoServiceProviderProxy as some CryptoServiceProviders (CSP's) do // not natively support SHA2. #if DESKTOP if (rsa is RSACryptoServiceProvider rsaCryptoServiceProvider) @@ -149,13 +168,12 @@ private void InitializeUsingRsa(RSA rsa, string algorithm) || algorithm.Equals(SecurityAlgorithms.RsaOaepKeyWrap); RsaCryptoServiceProviderProxy = new RSACryptoServiceProviderProxy(rsaCryptoServiceProvider); - DecryptFunction = DecryptWithRsaCryptoServiceProviderProxy; - EncryptFunction = EncryptWithRsaCryptoServiceProviderProxy; - SignatureFunction = SignWithRsaCryptoServiceProviderProxy; - VerifyFunction = VerifyWithRsaCryptoServiceProviderProxy; -#if NET461_OR_GREATER - VerifyFunctionWithLength = VerifyWithRsaCryptoServiceProviderProxyWithLength; -#endif + _decryptFunction = DecryptWithRsaCryptoServiceProviderProxy; + _encryptFunction = EncryptWithRsaCryptoServiceProviderProxy; + _signFunction = SignWithRsaCryptoServiceProviderProxy; + _signUsingOffsetFunction = SignWithRsaCryptoServiceProviderProxyUsingOffset; + _verifyFunction = VerifyWithRsaCryptoServiceProviderProxy; + _verifyUsingOffsetFunction = VerifyWithRsaCryptoServiceProviderProxyUsingOffset; // RSACryptoServiceProviderProxy will track if a new RSA object is created and dispose appropriately. _disposeCryptoOperators = true; return; @@ -181,11 +199,15 @@ private void InitializeUsingRsa(RSA rsa, string algorithm) ? RSAEncryptionPadding.OaepSHA1 : RSAEncryptionPadding.Pkcs1; RSA = rsa; - DecryptFunction = DecryptWithRsa; - EncryptFunction = EncryptWithRsa; - SignatureFunction = SignWithRsa; - VerifyFunction = VerifyWithRsa; - VerifyFunctionWithLength = VerifyWithRsaWithLength; + _decryptFunction = DecryptWithRsa; + _encryptFunction = EncryptWithRsa; + _signFunction = SignRsa; + _signUsingOffsetFunction = SignUsingOffsetRsa; +#if NET6_0_OR_GREATER + _signUsingSpanFunction = SignUsingSpanRsa; +#endif + _verifyFunction = VerifyRsa; + _verifyUsingOffsetFunction = VerifyUsingOffsetRsa; } private void InitializeUsingRsaSecurityKey(RsaSecurityKey rsaSecurityKey, string algorithm) @@ -219,60 +241,101 @@ private void InitializeUsingX509SecurityKey(X509SecurityKey x509SecurityKey, str internal byte[] Sign(byte[] bytes) { - return SignatureFunction(bytes); + return _signFunction(bytes); + } + +#if NET6_0_OR_GREATER + internal bool SignUsingSpan(ReadOnlySpan data, Span destination, out int bytesWritten) + { + return _signUsingSpanFunction(data, destination, out bytesWritten); + } +#endif + + internal byte[] SignUsingOffset(byte[] bytes, int offset, int count) + { + return _signUsingOffsetFunction(bytes, offset, count); } - private static byte[] SignatureFunctionNotFound(byte[] _) + private static byte[] SignFunctionNotFound(byte[] _) { // we should never get here, its a bug if we do. throw LogHelper.LogExceptionMessage(new CryptographicException(LogMessages.IDX10685)); } - private byte[] SignWithECDsa(byte[] bytes) + private static byte[] SignUsingOffsetNotFound(byte[] b, int c, int d) + { + // we should never get here, its a bug if we do. + throw LogHelper.LogExceptionMessage(new CryptographicException(LogMessages.IDX10685)); + } + +#if NET6_0_OR_GREATER +#pragma warning disable CA1801 // Review unused parameters + private static bool SignUsingSpanNotFound(ReadOnlySpan data, Span destination, out int bytesWritten) +#pragma warning restore CA1801 // Review unused parameters + { + // we should never get here, its a bug if we do. + throw LogHelper.LogExceptionMessage(new CryptographicException(LogMessages.IDX10685)); + } +#endif + + private byte[] SignECDsa(byte[] bytes) { return ECDsa.SignHash(HashAlgorithm.ComputeHash(bytes)); } +#if NET6_0_OR_GREATER + internal bool SignUsingSpanECDsa(ReadOnlySpan data, Span destination, out int bytesWritten) + { + // ECDSA.TrySignData will return true and set bytesWritten = 64, if destination is null. + if (destination.Length == 0) + { + bytesWritten = 0; + return false; + } + + bool success = ECDsa.TrySignData(data, destination, HashAlgorithmName, out bytesWritten); + if (!success || bytesWritten == 0) + return false; + + return destination.Length >= bytesWritten; + } +#endif + + private byte[] SignUsingOffsetECDsa(byte[] bytes, int offset, int count) + { + return ECDsa.SignHash(HashAlgorithm.ComputeHash(bytes, offset, count)); + } + internal bool Verify(byte[] bytes, byte[] signature) { - return VerifyFunction(bytes, signature); + return _verifyFunction(bytes, signature); } - internal bool Verify(byte[] bytes, int start, int length, byte[] signature) + internal bool VerifyUsingOffset(byte[] bytes, int offset, int count, byte[] signature) { - return VerifyFunctionWithLength(bytes, start, length, signature); + return _verifyUsingOffsetFunction(bytes, offset, count, signature); } - private static bool VerifyFunctionNotFound(byte[] bytes, byte[] signature) + private static bool VerifyNotFound(byte[] bytes, byte[] signature) { // we should never get here, its a bug if we do. throw LogHelper.LogExceptionMessage(new NotSupportedException(LogMessages.IDX10686)); } - private static bool VerifyFunctionWithLengthNotFound(byte[] bytes, int start, int length, byte[] signature) + private static bool VerifyUsingOffsetNotFound(byte[] bytes, int offset, int count, byte[] signature) { // we should never get here, its a bug if we do. throw LogHelper.LogExceptionMessage(new NotSupportedException(LogMessages.IDX10686)); } - private bool VerifyWithECDsa(byte[] bytes, byte[] signature) + private bool VerifyECDsa(byte[] bytes, byte[] signature) { return ECDsa.VerifyHash(HashAlgorithm.ComputeHash(bytes), signature); } - private bool VerifyWithECDsaWithLength(byte[] bytes, int start, int length, byte[] signature) + private bool VerifyUsingOffsetECDsa(byte[] bytes, int offset, int count, byte[] signature) { - return ECDsa.VerifyHash(HashAlgorithm.ComputeHash(bytes, start, length), signature); - } - -#region NET61+ related code -#if NET461 || NET462 || NET472 || NETSTANDARD2_0 || NET6_0_OR_GREATER - - // HasAlgorithmName was introduced into Net46 - internal AsymmetricAdapter(SecurityKey key, string algorithm, HashAlgorithm hashAlgorithm, HashAlgorithmName hashAlgorithmName, bool requirePrivateKey) - : this(key, algorithm, hashAlgorithm, requirePrivateKey) - { - HashAlgorithmName = hashAlgorithmName; + return ECDsa.VerifyHash(HashAlgorithm.ComputeHash(bytes, offset, count), signature); } private byte[] DecryptWithRsa(byte[] bytes) @@ -291,22 +354,32 @@ private byte[] EncryptWithRsa(byte[] bytes) private RSASignaturePadding RSASignaturePadding { get; set; } - private byte[] SignWithRsa(byte[] bytes) + private byte[] SignRsa(byte[] bytes) { return RSA.SignHash(HashAlgorithm.ComputeHash(bytes), HashAlgorithmName, RSASignaturePadding); } - private bool VerifyWithRsa(byte[] bytes, byte[] signature) +#if NET6_0_OR_GREATER + internal bool SignUsingSpanRsa(ReadOnlySpan data, Span destination, out int bytesWritten) + { + return RSA.TrySignData(data, destination, HashAlgorithmName, RSASignaturePadding, out bytesWritten); + } +#endif + + private byte[] SignUsingOffsetRsa(byte[] bytes, int offset, int count) + { + return RSA.SignData(bytes, offset, count, HashAlgorithmName, RSASignaturePadding); + } + + private bool VerifyRsa(byte[] bytes, byte[] signature) { return RSA.VerifyHash(HashAlgorithm.ComputeHash(bytes), signature, HashAlgorithmName, RSASignaturePadding); } - private bool VerifyWithRsaWithLength(byte[] bytes, int start, int length, byte[] signature) + private bool VerifyUsingOffsetRsa(byte[] bytes, int offset, int count, byte[] signature) { - return RSA.VerifyHash(HashAlgorithm.ComputeHash(bytes, start, length), signature, HashAlgorithmName, RSASignaturePadding); + return RSA.VerifyHash(HashAlgorithm.ComputeHash(bytes, offset, count), signature, HashAlgorithmName, RSASignaturePadding); } -#endif -#endregion #region DESKTOP related code #if DESKTOP @@ -326,18 +399,20 @@ internal byte[] SignWithRsaCryptoServiceProviderProxy(byte[] bytes) { return RsaCryptoServiceProviderProxy.SignData(bytes, HashAlgorithm); } + internal byte[] SignWithRsaCryptoServiceProviderProxyUsingOffset(byte[] bytes, int offset, int length) + { + return RsaCryptoServiceProviderProxy.SignData(bytes, offset, length, HashAlgorithm); + } private bool VerifyWithRsaCryptoServiceProviderProxy(byte[] bytes, byte[] signature) { return RsaCryptoServiceProviderProxy.VerifyData(bytes, HashAlgorithm, signature); } - #if NET461_OR_GREATER - private bool VerifyWithRsaCryptoServiceProviderProxyWithLength(byte[] bytes, int offset, int length, byte[] signature) + private bool VerifyWithRsaCryptoServiceProviderProxyUsingOffset(byte[] bytes, int offset, int length, byte[] signature) { return RsaCryptoServiceProviderProxy.VerifyDataWithLength(bytes, offset, length, HashAlgorithm, HashAlgorithmName, signature); } - #endif #endif #endregion diff --git a/src/Microsoft.IdentityModel.Tokens/AsymmetricSignatureProvider.cs b/src/Microsoft.IdentityModel.Tokens/AsymmetricSignatureProvider.cs index e4c2f3534b..4ad0cf16fb 100644 --- a/src/Microsoft.IdentityModel.Tokens/AsymmetricSignatureProvider.cs +++ b/src/Microsoft.IdentityModel.Tokens/AsymmetricSignatureProvider.cs @@ -194,6 +194,46 @@ internal bool ValidKeySize() /// internal override int ObjectPoolSize => _asymmetricAdapterObjectPool.Size; +#if NET6_0_OR_GREATER + /// + /// This must be overridden to produce a signature over the 'input'. + /// + /// bytes to sign. + /// pre allocated span where signature bytes will be placed. + /// number of bytes written into the signature span. + /// returns true if creation of signature succeeded, false otherwise. + internal override bool Sign(ReadOnlySpan input, Span signature, out int bytesWritten) + { + if (input == null || input.Length == 0) + throw LogHelper.LogArgumentNullException(nameof(input)); + + if (_disposed) + { + CryptoProviderCache?.TryRemove(this); + throw LogHelper.LogExceptionMessage(new ObjectDisposedException(GetType().ToString())); + } + + AsymmetricAdapter asym = null; + try + { + asym = _asymmetricAdapterObjectPool.Allocate(); + return asym.SignUsingSpan(input, signature, out bytesWritten); + } + catch + { + CryptoProviderCache?.TryRemove(this); + Dispose(true); + throw; + } + finally + { + if (!_disposed) + _asymmetricAdapterObjectPool.Free(asym); + } + + } +#endif + /// /// Produces a signature over the 'input' using the and algorithm passed to . /// @@ -233,6 +273,36 @@ public override byte[] Sign(byte[] input) } } + internal override byte[] Sign(byte[] input, int offset, int count) + { + if (input == null || input.Length == 0) + throw LogHelper.LogArgumentNullException(nameof(input)); + + if (_disposed) + { + CryptoProviderCache?.TryRemove(this); + throw LogHelper.LogExceptionMessage(new ObjectDisposedException(GetType().ToString())); + } + + AsymmetricAdapter asym = null; + try + { + asym = _asymmetricAdapterObjectPool.Allocate(); + return asym.SignUsingOffset(input, offset, count); + } + catch + { + CryptoProviderCache?.TryRemove(this); + Dispose(true); + throw; + } + finally + { + if (!_disposed) + _asymmetricAdapterObjectPool.Free(asym); + } + } + /// /// Validates that an asymmetric key size is of sufficient size for a SignatureAlgorithm. /// @@ -406,7 +476,7 @@ public override bool Verify(byte[] input, int inputOffset, int inputLength, byte asym = _asymmetricAdapterObjectPool.Allocate(); if (signature.Length == signatureLength) { - return asym.Verify(input, inputOffset, inputLength, signature); + return asym.VerifyUsingOffset(input, inputOffset, inputLength, signature); } else { @@ -414,7 +484,7 @@ public override bool Verify(byte[] input, int inputOffset, int inputLength, byte // Having the logic here, handles EC and RSA. We can revisit when we start using spans in 3.1+. byte[] signatureBytes = new byte[signatureLength]; Array.Copy(signature, 0, signatureBytes, 0, signatureLength); - return asym.Verify(input, inputOffset, inputLength, signatureBytes); + return asym.VerifyUsingOffset(input, inputOffset, inputLength, signatureBytes); } } catch diff --git a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs index d30ec32dc8..db6d1bc0d2 100644 --- a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs +++ b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoder.cs @@ -37,6 +37,20 @@ public static string Encode(string arg) return Encode(Encoding.UTF8.GetBytes(arg)); } + /// + /// Converts a subset of an array of 8-bit unsigned integers to its equivalent string representation which is encoded with base-64-url digits. + /// + /// An array of 8-bit unsigned integers. + /// The string representation in base 64 url encoding of length elements of inArray, starting at position offset. + /// 'inArray' is null. + /// offset or length is negative OR offset plus length is greater than the length of inArray. + public static string Encode(byte[] inArray) + { + _ = inArray ?? throw LogHelper.LogArgumentNullException(nameof(inArray)); + + return Encode(inArray, 0, inArray.Length); + } + /// /// Converts a subset of an array of 8-bit unsigned integers to its equivalent string representation which is encoded with base-64-url digits. Parameters specify /// the subset as an offset in the input array, and the number of elements in the array to convert. @@ -53,6 +67,7 @@ public static string Encode(byte[] inArray, int offset, int length) if (offset < 0) throw LogHelper.LogExceptionMessage(new ArgumentOutOfRangeException( + nameof(offset), LogHelper.FormatInvariant( LogMessages.IDX10716, LogHelper.MarkAsNonPII(nameof(offset)), @@ -63,13 +78,16 @@ public static string Encode(byte[] inArray, int offset, int length) if (length < 0) throw LogHelper.LogExceptionMessage(new ArgumentOutOfRangeException( + nameof(length), LogHelper.FormatInvariant( LogMessages.IDX10716, LogHelper.MarkAsNonPII(nameof(length)), LogHelper.MarkAsNonPII(length)))); if (inArray.Length < offset + length) +#pragma warning disable CA2208 // Instantiate argument exceptions correctly throw LogHelper.LogExceptionMessage(new ArgumentOutOfRangeException( + "offset + length", LogHelper.FormatInvariant( LogMessages.IDX10717, LogHelper.MarkAsNonPII(nameof(offset)), @@ -78,16 +96,31 @@ public static string Encode(byte[] inArray, int offset, int length) LogHelper.MarkAsNonPII(offset), LogHelper.MarkAsNonPII(length), LogHelper.MarkAsNonPII(inArray.Length)))); +#pragma warning restore CA2208 // Instantiate argument exceptions correctly + + char[] destination = new char[(inArray.Length + 2) / 3 * 4]; + int j = Encode(inArray.AsSpan().Slice(offset, length), destination.AsSpan()); - int lengthmod3 = length % 3; - int limit = offset + (length - lengthmod3); - char[] output = new char[(length + 2) / 3 * 4]; + return new string(destination, 0, j); + } + + /// + /// Populates a Converts a encoded with base-64-url digits. Parameters specify + /// the subset as an offset in the input array, and the number of elements in the array to convert. + /// + /// A span of bytes. + /// output for encoding. + /// The number of chars written to the output. + public static int Encode(ReadOnlySpan inArray, Span output) + { + int lengthmod3 = inArray.Length % 3; + int limit = (inArray.Length - lengthmod3); ReadOnlySpan table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"u8; int i, j = 0; // takes 3 bytes from inArray and insert 4 bytes into output - for (i = offset; i < limit; i += 3) + for (i = 0; i < limit; i += 3) { byte d0 = inArray[i]; byte d1 = inArray[i + 1]; @@ -130,28 +163,7 @@ public static string Encode(byte[] inArray, int offset, int length) //default or case 0: no further operations are needed. } - return new string(output, 0, j); - } - - /// - /// Converts a subset of an array of 8-bit unsigned integers to its equivalent string representation which is encoded with base-64-url digits. - /// - /// An array of 8-bit unsigned integers. - /// The string representation in base 64 url encoding of length elements of inArray, starting at position offset. - /// 'inArray' is null. - /// offset or length is negative OR offset plus length is greater than the length of inArray. - public static string Encode(byte[] inArray) - { - _ = inArray ?? throw LogHelper.LogArgumentNullException(nameof(inArray)); - - return Encode(inArray, 0, inArray.Length); - } - - internal static string EncodeString(string str) - { - _ = str ?? throw LogHelper.LogArgumentNullException(nameof(str)); - - return Encode(Encoding.UTF8.GetBytes(str)); + return j; } /// diff --git a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoding.cs b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoding.cs index 05d14c9255..4c5b6effe6 100644 --- a/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoding.cs +++ b/src/Microsoft.IdentityModel.Tokens/Base64UrlEncoding.cs @@ -286,6 +286,7 @@ public static string Encode(byte[] input, int offset, int length) if (length < 0) throw LogHelper.LogExceptionMessage(new ArgumentOutOfRangeException( + nameof(length), LogHelper.FormatInvariant( LogMessages.IDX10716, LogHelper.MarkAsNonPII(nameof(length)), @@ -293,13 +294,16 @@ public static string Encode(byte[] input, int offset, int length) if (offset < 0) throw LogHelper.LogExceptionMessage(new ArgumentOutOfRangeException( + nameof(offset), LogHelper.FormatInvariant( LogMessages.IDX10716, LogHelper.MarkAsNonPII(nameof(offset)), LogHelper.MarkAsNonPII(offset)))); if (input.Length < offset + length) +#pragma warning disable CA2208 // Instantiate argument exceptions correctly throw LogHelper.LogExceptionMessage(new ArgumentOutOfRangeException( + "offset + length", LogHelper.FormatInvariant( LogMessages.IDX10717, LogHelper.MarkAsNonPII(nameof(offset)), @@ -308,6 +312,7 @@ public static string Encode(byte[] input, int offset, int length) LogHelper.MarkAsNonPII(offset), LogHelper.MarkAsNonPII(length), LogHelper.MarkAsNonPII(input.Length)))); +#pragma warning restore CA2208 // Instantiate argument exceptions correctly int outputsize = length % 3; if (outputsize > 0) diff --git a/src/Microsoft.IdentityModel.Tokens/RsaCryptoServiceProviderProxy.cs b/src/Microsoft.IdentityModel.Tokens/RsaCryptoServiceProviderProxy.cs index 5cbdb9183e..ea2adbee28 100644 --- a/src/Microsoft.IdentityModel.Tokens/RsaCryptoServiceProviderProxy.cs +++ b/src/Microsoft.IdentityModel.Tokens/RsaCryptoServiceProviderProxy.cs @@ -171,6 +171,16 @@ public byte[] SignData(byte[] input, object hash) return _rsa.SignData(input, hash); } + internal byte[] SignData(byte[] input, int offset, int length, object hash) + { + if (input == null || input.Length == 0) + throw LogHelper.LogArgumentNullException(nameof(input)); + + _ = hash ?? throw LogHelper.LogArgumentNullException(nameof(hash)); + + return _rsa.SignData(input, offset, length, hash); + } + /// /// Verifies that a digital signature is valid by determining the hash value in the signature using the provided public key and comparing it to the hash value of the provided data. /// diff --git a/src/Microsoft.IdentityModel.Tokens/SignatureProvider.cs b/src/Microsoft.IdentityModel.Tokens/SignatureProvider.cs index 4ddee4e2dd..8491b3b28d 100644 --- a/src/Microsoft.IdentityModel.Tokens/SignatureProvider.cs +++ b/src/Microsoft.IdentityModel.Tokens/SignatureProvider.cs @@ -95,6 +95,17 @@ internal int Release() /// signed bytes public abstract byte[] Sign(byte[] input); + internal virtual byte[] Sign(byte[] input, int offset, int count) + { + throw LogHelper.LogExceptionMessage(new NotImplementedException()); + } + +#if NET6_0_OR_GREATER + internal virtual bool Sign(ReadOnlySpan data, Span destination, out int bytesWritten) + { + throw LogHelper.LogExceptionMessage(new NotImplementedException()); + } +#endif /// Verifies that the over using the /// and specified by this /// are consistent. diff --git a/src/Microsoft.IdentityModel.Tokens/SupportedAlgorithms.cs b/src/Microsoft.IdentityModel.Tokens/SupportedAlgorithms.cs index a44984b93a..0735b9dbb8 100644 --- a/src/Microsoft.IdentityModel.Tokens/SupportedAlgorithms.cs +++ b/src/Microsoft.IdentityModel.Tokens/SupportedAlgorithms.cs @@ -358,5 +358,46 @@ internal static bool IsSupportedSymmetricAlgorithm(string algorithm) || SymmetricKeyWrapAlgorithms.Contains(algorithm) || SymmetricSigningAlgorithms.Contains(algorithm); } + + /// + /// Returns the maximum size in bytes for a supported signature algorithms. + /// The key size affects the signature size for asymmetric algorithms. + /// + /// + /// Set size for known algorithms, 2K default. + internal static int GetMaxByteCount(string algorithm) => algorithm switch + { + SecurityAlgorithms.HmacSha256 or + SecurityAlgorithms.HmacSha256Signature => 32, + + SecurityAlgorithms.HmacSha384 or + SecurityAlgorithms.HmacSha384Signature => 48, + + SecurityAlgorithms.HmacSha512 or + SecurityAlgorithms.HmacSha512Signature => 64, + + SecurityAlgorithms.EcdsaSha256 or + SecurityAlgorithms.EcdsaSha256Signature or + SecurityAlgorithms.EcdsaSha384 or + SecurityAlgorithms.EcdsaSha384Signature or + SecurityAlgorithms.RsaSha256 or + SecurityAlgorithms.RsaSha256Signature or + SecurityAlgorithms.RsaSsaPssSha256 or + SecurityAlgorithms.RsaSsaPssSha256Signature or + SecurityAlgorithms.RsaSha384 or + SecurityAlgorithms.RsaSsaPssSha384 or + SecurityAlgorithms.RsaSsaPssSha384Signature or + SecurityAlgorithms.RsaSha384Signature => 512, + + SecurityAlgorithms.EcdsaSha512 or + SecurityAlgorithms.EcdsaSha512Signature or + SecurityAlgorithms.RsaSha512 or + SecurityAlgorithms.RsaSsaPssSha512 or + SecurityAlgorithms.RsaSsaPssSha512Signature or + SecurityAlgorithms.RsaSha512Signature => 1024, + + // if we don't know the algorithm, report 2K twice as big as any known algorithm. + _ => 2048, + }; } } diff --git a/src/Microsoft.IdentityModel.Tokens/SymmetricSignatureProvider.cs b/src/Microsoft.IdentityModel.Tokens/SymmetricSignatureProvider.cs index e65cd49611..ac83fee782 100644 --- a/src/Microsoft.IdentityModel.Tokens/SymmetricSignatureProvider.cs +++ b/src/Microsoft.IdentityModel.Tokens/SymmetricSignatureProvider.cs @@ -22,7 +22,7 @@ public class SymmetricSignatureProvider : SignatureProvider /// /// Mapping from algorithm to the expected signature size in bytes. /// - private static readonly Dictionary _expectedSignatureSizeInBytes = new Dictionary + internal static readonly Dictionary ExpectedSignatureSizeInBytes = new Dictionary { { SecurityAlgorithms.HmacSha256, 32 }, { SecurityAlgorithms.HmacSha256Signature, 32 }, @@ -203,6 +203,71 @@ public override byte[] Sign(byte[] input) } } +#if NET6_0_OR_GREATER + internal override bool Sign(ReadOnlySpan input, Span signature, out int bytesWritten) + { + if (input == null || input.Length == 0) + throw LogHelper.LogArgumentNullException(nameof(input)); + + if (_disposed) + { + CryptoProviderCache?.TryRemove(this); + throw LogHelper.LogExceptionMessage(new ObjectDisposedException(GetType().ToString())); + } + + KeyedHashAlgorithm keyedHashAlgorithm = GetKeyedHashAlgorithm(GetKeyBytes(Key), Algorithm); + + try + { + return keyedHashAlgorithm.TryComputeHash(input, signature, out bytesWritten); + } + catch + { + CryptoProviderCache?.TryRemove(this); + Dispose(true); + throw; + } + finally + { + if (!_disposed) + ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); + } + } +#endif + + internal override byte[] Sign(byte[] input, int offset, int count) + { + if (input == null || input.Length == 0) + throw LogHelper.LogArgumentNullException(nameof(input)); + + if (_disposed) + { + CryptoProviderCache?.TryRemove(this); + throw LogHelper.LogExceptionMessage(new ObjectDisposedException(GetType().ToString())); + } + + if (LogHelper.IsEnabled(EventLogLevel.Informational)) + LogHelper.LogInformation(LogMessages.IDX10642, input); + + KeyedHashAlgorithm keyedHashAlgorithm = GetKeyedHashAlgorithm(GetKeyBytes(Key), Algorithm); + + try + { + return keyedHashAlgorithm.ComputeHash(input, offset, count); + } + catch + { + CryptoProviderCache?.TryRemove(this); + Dispose(true); + throw; + } + finally + { + if (!_disposed) + ReleaseKeyedHashAlgorithm(keyedHashAlgorithm); + } + } + /// /// Verifies that a signature created over the 'input' matches the signature. Using and 'algorithm' passed to . /// @@ -362,7 +427,7 @@ internal bool Verify(byte[] input, int inputOffset, int inputLength, byte[] sign // Check that signature length matches algorithm. // If we don't have an entry for the algorithm in our dictionary, that is probably a bug. // This is why a new message was created, rather than using IDX10640. - if (!_expectedSignatureSizeInBytes.TryGetValue(algorithmToValidate, out int expectedSignatureLength)) + if (!ExpectedSignatureSizeInBytes.TryGetValue(algorithmToValidate, out int expectedSignatureLength)) throw LogHelper.LogExceptionMessage(new ArgumentException( LogHelper.FormatInvariant( LogMessages.IDX10718, @@ -413,7 +478,6 @@ internal bool Verify(byte[] input, int inputOffset, int inputLength, byte[] sign } } - #region IDisposable Members /// diff --git a/src/Microsoft.IdentityModel.Tokens/TokenUtilities.cs b/src/Microsoft.IdentityModel.Tokens/TokenUtilities.cs index a6ae40965b..e79df376da 100644 --- a/src/Microsoft.IdentityModel.Tokens/TokenUtilities.cs +++ b/src/Microsoft.IdentityModel.Tokens/TokenUtilities.cs @@ -83,7 +83,7 @@ internal static Dictionary CreateDictionaryFromClaims(IEnumerabl return payload; } - internal static IDictionary CreateDictionaryFromClaims( + internal static Dictionary CreateDictionaryFromClaims( IEnumerable claims, SecurityTokenDescriptor tokenDescriptor, bool audienceSet, diff --git a/test/Microsoft.IdentityModel.JsonWebTokens.Tests/JsonWebTokenHandlerTests.cs b/test/Microsoft.IdentityModel.JsonWebTokens.Tests/JsonWebTokenHandlerTests.cs index d178822941..21ee278858 100644 --- a/test/Microsoft.IdentityModel.JsonWebTokens.Tests/JsonWebTokenHandlerTests.cs +++ b/test/Microsoft.IdentityModel.JsonWebTokens.Tests/JsonWebTokenHandlerTests.cs @@ -1283,6 +1283,154 @@ public static TheoryData CreateJWSTheoryData } } + // This test checks to make sure that SecurityTokenDescriptor.Audience, Expires, IssuedAt, NotBefore, Issuer have priority over SecurityTokenDescriptor.Claims. + [Theory, MemberData(nameof(CreateJWSWithSecurityTokenDescriptorClaimsTheoryData))] + public void CreateJWSWithSecurityTokenDescriptorClaims(CreateTokenTheoryData theoryData) + { + var context = TestUtilities.WriteHeader($"{this}.CreateJWSWithSecurityTokenDescriptorClaims", theoryData); + + var jwtToken = new JsonWebTokenHandler().CreateToken(theoryData.TokenDescriptor); + JsonWebToken jsonWebToken = new JsonWebToken(jwtToken); + + jsonWebToken.TryGetPayloadValue("iss", out string issuer); + IdentityComparer.AreEqual(theoryData.ExpectedClaims["iss"], issuer, context); + + jsonWebToken.TryGetPayloadValue("aud", out string audience); + IdentityComparer.AreEqual(theoryData.ExpectedClaims["aud"], audience, context); + + jsonWebToken.TryGetPayloadValue("exp", out long exp); + IdentityComparer.AreEqual(theoryData.ExpectedClaims["exp"], exp, context); + + jsonWebToken.TryGetPayloadValue("iat", out long iat); + IdentityComparer.AreEqual(theoryData.ExpectedClaims["iat"], iat, context); + + jsonWebToken.TryGetPayloadValue("nbf", out long nbf); + IdentityComparer.AreEqual(theoryData.ExpectedClaims["nbf"], nbf, context); + + TestUtilities.AssertFailIfErrors(context); + } + + public static TheoryData CreateJWSWithSecurityTokenDescriptorClaimsTheoryData + { + get + { + TheoryData theoryData = new TheoryData(); + + SigningCredentials signingCredentials = new SigningCredentials(KeyingMaterial.DefaultSymmetricSecurityKey_256, SecurityAlgorithms.HmacSha256, SecurityAlgorithms.Sha256); + + DateTime iat = DateTime.UtcNow; + DateTime exp = iat + TimeSpan.FromDays(1); + DateTime nbf = iat + TimeSpan.FromMinutes(1); + string iss = Guid.NewGuid().ToString(); + string aud = Guid.NewGuid().ToString(); + + Dictionary claims = new Dictionary() + { + { JwtRegisteredClaimNames.Aud, aud }, + { JwtRegisteredClaimNames.Exp, EpochTime.GetIntDate(exp) }, + { JwtRegisteredClaimNames.Iat, EpochTime.GetIntDate(iat) }, + { JwtRegisteredClaimNames.Iss, iss}, + { JwtRegisteredClaimNames.Nbf, EpochTime.GetIntDate(nbf) } + }; + + // These values will be set on the SecurityTokenDescriptor + DateTime iatSTD = DateTime.UtcNow + TimeSpan.FromHours(1); + DateTime expSTD = iat + TimeSpan.FromDays(1); + DateTime nbfSTD = iat + TimeSpan.FromMinutes(1); + string issSTD = Guid.NewGuid().ToString(); + string audSTD = Guid.NewGuid().ToString(); + + theoryData.Add(new CreateTokenTheoryData("ValuesFromClaims") + { + TokenDescriptor = new SecurityTokenDescriptor + { + SigningCredentials = signingCredentials, + Claims = claims + }, + ExpectedClaims = claims + }); + + theoryData.Add(new CreateTokenTheoryData("AllValuesFromSTD") + { + TokenDescriptor = new SecurityTokenDescriptor + { + SigningCredentials = signingCredentials, + Claims = claims, + IssuedAt = iatSTD, + Expires = expSTD, + NotBefore = nbfSTD, + Audience = audSTD, + Issuer = issSTD + }, + ExpectedClaims = new Dictionary() + { + { JwtRegisteredClaimNames.Aud, audSTD }, + { JwtRegisteredClaimNames.Exp, EpochTime.GetIntDate(expSTD) }, + { JwtRegisteredClaimNames.Iat, EpochTime.GetIntDate(iatSTD) }, + { JwtRegisteredClaimNames.Iss, issSTD}, + { JwtRegisteredClaimNames.Nbf, EpochTime.GetIntDate(nbfSTD) } + } + }); + + theoryData.Add(new CreateTokenTheoryData("ExpFromSTD") + { + TokenDescriptor = new SecurityTokenDescriptor + { + SigningCredentials = signingCredentials, + Claims = claims, + Expires = expSTD + }, + ExpectedClaims = new Dictionary() + { + { JwtRegisteredClaimNames.Aud, aud }, + { JwtRegisteredClaimNames.Exp, EpochTime.GetIntDate(expSTD) }, + { JwtRegisteredClaimNames.Iat, EpochTime.GetIntDate(iat) }, + { JwtRegisteredClaimNames.Iss, iss}, + { JwtRegisteredClaimNames.Nbf, EpochTime.GetIntDate(nbf) } + } + }); + + theoryData.Add(new CreateTokenTheoryData("IatFromSTD") + { + TokenDescriptor = new SecurityTokenDescriptor + { + SigningCredentials = signingCredentials, + Claims = claims, + IssuedAt = iatSTD + }, + ExpectedClaims = new Dictionary() + { + { JwtRegisteredClaimNames.Aud, aud }, + { JwtRegisteredClaimNames.Exp, EpochTime.GetIntDate(exp) }, + { JwtRegisteredClaimNames.Iat, EpochTime.GetIntDate(iatSTD) }, + { JwtRegisteredClaimNames.Iss, iss}, + { JwtRegisteredClaimNames.Nbf, EpochTime.GetIntDate(nbf) } + } + }); + + theoryData.Add(new CreateTokenTheoryData("NbfFromSTD") + { + TokenDescriptor = new SecurityTokenDescriptor + { + SigningCredentials = signingCredentials, + Claims = claims, + NotBefore = nbfSTD + }, + ExpectedClaims = new Dictionary() + { + { JwtRegisteredClaimNames.Aud, aud }, + { JwtRegisteredClaimNames.Exp, EpochTime.GetIntDate(exp) }, + { JwtRegisteredClaimNames.Iat, EpochTime.GetIntDate(iat) }, + { JwtRegisteredClaimNames.Iss, iss}, + { JwtRegisteredClaimNames.Nbf, EpochTime.GetIntDate(nbfSTD) } + } + }); + + return theoryData; + } + } + + // This test checks to make sure that additional header claims are added as expected to the JWT token header. [Theory, MemberData(nameof(CreateJWSWithAdditionalHeaderClaimsTheoryData))] public void CreateJWSWithAdditionalHeaderClaims(CreateTokenTheoryData theoryData) @@ -3880,9 +4028,8 @@ public CreateTokenTheoryData() { } - public CreateTokenTheoryData(string testId) + public CreateTokenTheoryData(string testId) : base(testId) { - TestId = testId; } public Dictionary AdditionalHeaderClaims { get; set; } @@ -3916,6 +4063,8 @@ public CreateTokenTheoryData(string testId) public string Algorithm { get; set; } public IEnumerable ExpectedDecryptionKeys { get; set; } + + public Dictionary ExpectedClaims { get; set; } } // Overrides CryptoProviderFactory.CreateAuthenticatedEncryptionProvider to create AuthenticatedEncryptionProviderMock that provides AesGcm encryption. diff --git a/test/Microsoft.IdentityModel.TestUtils/IdentityComparer.cs b/test/Microsoft.IdentityModel.TestUtils/IdentityComparer.cs index e41402809f..97725342e1 100644 --- a/test/Microsoft.IdentityModel.TestUtils/IdentityComparer.cs +++ b/test/Microsoft.IdentityModel.TestUtils/IdentityComparer.cs @@ -58,6 +58,7 @@ public class IdentityComparer { typeof(IEnumerable).ToString(), AreSecurityKeyEnumsEqual }, { typeof(IEnumerable).ToString(), AreStringEnumsEqual }, { typeof(IEnumerable).ToString(), AreX509DataEnumsEqual }, + { typeof(int).ToString(), AreIntsEqual }, { typeof(IssuerSerial).ToString(), CompareAllPublicProperties }, { typeof(JArray).ToString(), AreJArraysEqual }, { typeof(JObject).ToString(), AreJObjectsEqual }, @@ -82,6 +83,7 @@ public class IdentityComparer { typeof(List).ToString(), AreSecurityKeyEnumsEqual }, { typeof(List).ToString(), AreReferenceEnumsEqual }, { typeof(List).ToString(), AreUriEnumsEqual }, + { typeof(long).ToString(), AreLongsEqual }, { typeof(OpenIdConnectConfiguration).ToString(), CompareAllPublicProperties }, { typeof(OpenIdConnectMessage).ToString(), CompareAllPublicProperties }, { typeof(Reference).ToString(), CompareAllPublicProperties }, @@ -152,6 +154,11 @@ public class IdentityComparer // Keep methods in alphabetical order public static bool AreBoolsEqual(object object1, object object2, CompareContext context) + { + return AreBoolsEqual(object1, object2, "bool1", "bool2", context); + } + + public static bool AreBoolsEqual(object object1, object object2, string name1, string name2, CompareContext context) { var localContext = new CompareContext(context); if (!ContinueCheckingEquality(object1, object2, localContext)) @@ -165,6 +172,7 @@ public static bool AreBoolsEqual(object object1, object object2, CompareContext if (bool1 != bool2) { + localContext.Diffs.Add($"{name1} != {name2}"); localContext.Diffs.Add($"'{bool1}'"); localContext.Diffs.Add($"!="); localContext.Diffs.Add($"'{bool2}'"); @@ -174,6 +182,11 @@ public static bool AreBoolsEqual(object object1, object object2, CompareContext } public static bool AreBytesEqual(object object1, object object2, CompareContext context) + { + return AreBytesEqual(object1, object2, "bytes1", "bytes2", context); + } + + public static bool AreBytesEqual(object object1, object object2, string name1, string name2, CompareContext context) { var localContext = new CompareContext(context); if (!ContinueCheckingEquality(object1, object2, localContext)) @@ -181,17 +194,24 @@ public static bool AreBytesEqual(object object1, object object2, CompareContext var bytes1 = (byte[])object1; var bytes2 = (byte[])object2; - if (bytes1.Length != bytes2.Length) { + localContext.Diffs.Add($"{name1} != {name2}"); localContext.Diffs.Add("(bytes1.Length != bytes2.Length)"); } else { + bool firstDiff = true; for (int i = 0; i < bytes1.Length; i++) { if (bytes1[i] != bytes2[i]) { + if (firstDiff) + { + firstDiff = false; + localContext.Diffs.Add($"{name1} != {name2}"); + } + localContext.Diffs.Add($"'{bytes1}'"); localContext.Diffs.Add("!="); localContext.Diffs.Add($"'{bytes2}'"); @@ -696,6 +716,29 @@ public static bool AreJwtSecurityTokensEqual(JwtSecurityToken jwt1, JwtSecurityT return context.Merge(localContext); } + public static bool AreIntsEqual(object object1, object object2, CompareContext context) + { + return AreIntsEqual((int)object1, (int)object2, "int1", "int2", context); + } + + public static bool AreIntsEqual(int int1, int int2, string name1, string name2, CompareContext context) + { + var localContext = new CompareContext(context); + + if (int1 == int2) + return true; + + if (int1 != int2) + { + localContext.Diffs.Add($"{name1} != {name2}"); + localContext.Diffs.Add($"'{int1}'"); + localContext.Diffs.Add($"!="); + localContext.Diffs.Add($"'{int2}'"); + } + + return context.Merge(localContext); + } + public static bool AreKeyInfosEqual(KeyInfo keyInfo1, KeyInfo keyInfo2, CompareContext context) { var localContext = new CompareContext(context); @@ -710,6 +753,34 @@ public static bool AreKeyInfoEnumsEqual(object object1, object object2, CompareC return AreEnumsEqual(object1 as IEnumerable, object2 as IEnumerable, context, AreEqual); } + public static bool AreLongsEqual(object object1, object object2, CompareContext context) + { + return AreLongsEqual(object1, object2, "long1", "long2", context); + } + + public static bool AreLongsEqual(object object1, object object2, string name1, string name2, CompareContext context) + { + var localContext = new CompareContext(context); + if (!ContinueCheckingEquality(object1, object2, localContext)) + return context.Merge(localContext); + + long long1 = (long)object1; + long long2 = Convert.ToInt64(Convert.ToDouble(object2)); + + if (long1 == long2) + return true; + + if (long1 != long2) + { + localContext.Diffs.Add($"{name1} != {name2}"); + localContext.Diffs.Add($"'{long1}'"); + localContext.Diffs.Add($"!="); + localContext.Diffs.Add($"'{long2}'"); + } + + return context.Merge(localContext); + } + public static bool AreObjectDictionariesEqual(Object object1, Object object2, CompareContext context) { var localContext = new CompareContext(context); @@ -962,6 +1033,11 @@ public static bool AreStringDictionariesEqual(Object object1, Object object2, Co } public static bool AreStringsEqual(object object1, object object2, CompareContext context) + { + return AreStringsEqual(object1, object2, "str1", "str2", context); + } + + public static bool AreStringsEqual(object object1, object object2, string name1, string name2, CompareContext context) { var localContext = new CompareContext(context); if (!ContinueCheckingEquality(object1, object2, localContext)) @@ -976,13 +1052,17 @@ public static bool AreStringsEqual(object object1, object object2, CompareContex if (ReferenceEquals(str1, str2)) return true; - if (str1 == null || str2 == null) - localContext.Diffs.Add("(str1 == null || str2 == null)"); + if (str1 == null) + localContext.Diffs.Add($"({name1} == null, {name2} == {str2}."); + + if(str2 == null) + localContext.Diffs.Add($"({name1} == {str1}, {name2} == null."); if (!string.Equals(str1, str2, context.StringComparison)) { - localContext.Diffs.Add($"str1 != str2, StringComparison: '{context.StringComparison}'"); + localContext.Diffs.Add($"{name1} != {name2}, StringComparison: '{context.StringComparison}'"); localContext.Diffs.Add(str1); + localContext.Diffs.Add($"!="); localContext.Diffs.Add(str2); } diff --git a/test/Microsoft.IdentityModel.Tokens.Tests/Base64UrlEncodingTests.cs b/test/Microsoft.IdentityModel.Tokens.Tests/Base64UrlEncodingTests.cs new file mode 100644 index 0000000000..2d95e766c5 --- /dev/null +++ b/test/Microsoft.IdentityModel.Tokens.Tests/Base64UrlEncodingTests.cs @@ -0,0 +1,311 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Text; +using Microsoft.IdentityModel.TestUtils; +using Xunit; + +namespace Microsoft.IdentityModel.Tokens.UrlEncoding.Tests +{ + public class Base64UrlEncoderTests + { + [Theory, MemberData(nameof(EncodeTestCases), DisableDiscoveryEnumeration = true)] + public void EncodeTests(Base64UrlEncoderTheoryData theoryData) + { + var context = TestUtilities.WriteHeader("EncodeTests", theoryData); + string encoderString = null; + string encoderBytes = null; + string encoderBytesUsingOffset = null; + string encoderBytesUsingSpan = null; + + try + { + // to get to error code in Base64UrlEncoding, we need to skip Encoder + if (!theoryData.EncodingOnly) + { + encoderString = Base64UrlEncoder.Encode(theoryData.Json); + encoderBytes = Base64UrlEncoder.Encode(theoryData.Bytes); + encoderBytesUsingOffset = Base64UrlEncoder.Encode(theoryData.OffsetBytes, theoryData.Offset, theoryData.Length); + int encodedCharsCount = Base64UrlEncoder.Encode(theoryData.OffsetBytes.AsSpan().Slice(theoryData.Offset, theoryData.Length), theoryData.Chars.AsSpan()); + encoderBytesUsingSpan = new string(theoryData.Chars, 0, encodedCharsCount); + } + + string encodingString = Base64UrlEncoding.Encode(theoryData.Bytes); + string encodingBytesUsingOffset = Base64UrlEncoding.Encode(theoryData.OffsetBytes, theoryData.Offset, theoryData.Length); + + theoryData.ExpectedException.ProcessNoException(context); + + if (!theoryData.EncodingOnly) + { + IdentityComparer.AreStringsEqual(encoderString, encoderBytes, "encoderString", "encoderBytes", context); + IdentityComparer.AreStringsEqual(encoderBytesUsingOffset, encoderBytes, "encoderBytesUsingOffset", "encoderBytes", context); + IdentityComparer.AreStringsEqual(encoderBytesUsingSpan, encoderBytes, "encoderBytesUsingSpan", "encoderBytes", context); + IdentityComparer.AreStringsEqual(encodingString, encoderBytes, "encodingString", "encoderBytes", context); + } + + IdentityComparer.AreStringsEqual(encodingBytesUsingOffset, encodingString, "encodingBytesUsingOffset", "encodingString", context); + IdentityComparer.AreStringsEqual(theoryData.ExpectedValue, encodingString, "theoryData.ExpectedValue", "encodingString", context); + + } + catch (Exception ex) + { + theoryData.ExpectedException.ProcessException(ex, context); + } + + TestUtilities.AssertFailIfErrors(context); + } + + public static TheoryData EncodeTestCases + { + get + { + TheoryData theoryData = new TheoryData(); + + // These values are sourced from https://datatracker.ietf.org/doc/html/rfc7519#section-6.1 + string json = $@"{{""alg"":""none""}}"; + string expectedValue = "eyJhbGciOiJub25lIn0"; + byte[] utf8Bytes = Encoding.UTF8.GetBytes(json); + + theoryData.Add(new Base64UrlEncoderTheoryData("Header_Offset_0") + { + Bytes = utf8Bytes, + Chars = new char[1024], + ExpectedValue = expectedValue, + Json = json, + Length = utf8Bytes.Length, + Offset = 0, + OffsetBytes = utf8Bytes, + OffsetLength = utf8Bytes.Length + }); + + // NOTE the spec performs the encoding over the \r\n and space ' '. + json = "{\"iss\":\"joe\",\r\n \"exp\":1300819380,\r\n \"http://example.com/is_root\":true}"; + expectedValue = "eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ"; + utf8Bytes = Encoding.UTF8.GetBytes(json); + + theoryData.Add(new Base64UrlEncoderTheoryData("Payload_Offset_0") + { + Bytes = utf8Bytes, + Chars = new char[1024], + ExpectedValue = expectedValue, + Json = json, + Length = utf8Bytes.Length, + Offset = 0, + OffsetBytes = utf8Bytes, + OffsetLength = utf8Bytes.Length + }); + + byte[] utf8BytesOffset = new byte[utf8Bytes.Length * 2]; + int count = Encoding.UTF8.GetBytes(json, 0, json.Length, utf8BytesOffset, 5); + theoryData.Add(new Base64UrlEncoderTheoryData("Payload_Offset_5") + { + Bytes = utf8Bytes, + Chars = new char[1024], + ExpectedValue = expectedValue, + Json = json, + Length = count, + Offset = 5, + OffsetBytes = utf8BytesOffset, + OffsetLength = count + }); + + theoryData.Add(new Base64UrlEncoderTheoryData("JsonNULL") + { + Bytes = utf8Bytes, + Chars = new char[1024], + ExpectedException = ExpectedException.ArgumentNullException("IDX10000:"), + ExpectedValue = expectedValue, + Json = null, + Length = count, + Offset = 5, + OffsetBytes = utf8BytesOffset, + OffsetLength = count + }); + + theoryData.Add(new Base64UrlEncoderTheoryData("BytesNULL") + { + Bytes = null, + Chars = new char[1024], + ExpectedException = ExpectedException.ArgumentNullException("IDX10000:"), + ExpectedValue = expectedValue, + Json = json, + Length = count, + Offset = 5, + OffsetBytes = utf8BytesOffset, + OffsetLength = count + }); + + theoryData.Add(new Base64UrlEncoderTheoryData("OffsetBytesNULL") + { + Bytes = utf8Bytes, + Chars = new char[1024], + ExpectedException = ExpectedException.ArgumentNullException("IDX10000:"), + ExpectedValue = expectedValue, + Json = json, + Length = count, + Offset = 5, + OffsetBytes = null, + OffsetLength = count + }); + + theoryData.Add(new Base64UrlEncoderTheoryData("Length_Negative") + { + Bytes = utf8Bytes, + Chars = new char[1024], + ExpectedException = ExpectedException.ArgumentOutOfRangeException("IDX10716:"), + ExpectedValue = expectedValue, + Json = json, + Length = -1, + Offset = 5, + OffsetBytes = utf8BytesOffset, + OffsetLength = 5 + }); + + theoryData.Add(new Base64UrlEncoderTheoryData("Length_Zero") + { + Bytes = new byte[0], + Chars = new char[1024], + ExpectedValue = string.Empty, + Json = string.Empty, + Length = 0, + Offset = 0, + OffsetBytes = new byte[0], + OffsetLength = 0 + }); + + theoryData.Add(new Base64UrlEncoderTheoryData("Bytes_Zero") + { + Bytes = new byte[0], + Chars = new char[1024], + ExpectedValue = string.Empty, + Json = string.Empty, + Length = 0, + Offset = 0, + OffsetBytes = new byte[0], + OffsetLength = 0 + }); + + theoryData.Add(new Base64UrlEncoderTheoryData("Input_LessThan_Offset_Length") + { + Bytes = utf8Bytes, + Chars = new char[1024], + ExpectedException = ExpectedException.ArgumentOutOfRangeException("IDX10717:"), + ExpectedValue = expectedValue, + Json = json, + Length = count, + Offset = utf8BytesOffset.Length, + OffsetBytes = utf8BytesOffset, + OffsetLength = utf8BytesOffset.Length + }); + + theoryData.Add(new Base64UrlEncoderTheoryData("BytesNULL_Encoding") + { + Bytes = null, + Chars = new char[1024], + ExpectedException = ExpectedException.ArgumentNullException("IDX10000:"), + ExpectedValue = expectedValue, + Length = count, + Json = json, + Offset = 5, + OffsetBytes = utf8BytesOffset, + OffsetLength = count, + EncodingOnly = true + }); + + theoryData.Add(new Base64UrlEncoderTheoryData("OffsetBytesNULL_Encoding") + { + Bytes = utf8Bytes, + Chars = new char[1024], + ExpectedException = ExpectedException.ArgumentNullException("IDX10000:"), + ExpectedValue = expectedValue, + Json = json, + Length = count, + Offset = 5, + OffsetBytes = null, + OffsetLength = count, + EncodingOnly = true + }); + + theoryData.Add(new Base64UrlEncoderTheoryData("Offset_Negative_Encoding") + { + Bytes = utf8Bytes, + Chars = new char[1024], + ExpectedException = ExpectedException.ArgumentOutOfRangeException("IDX10716:"), + ExpectedValue = expectedValue, + Json = json, + Length = count, + Offset = -1, + OffsetBytes = utf8BytesOffset, + OffsetLength = count, + EncodingOnly = true + }); + + theoryData.Add(new Base64UrlEncoderTheoryData("Length_Negative_Encoding") + { + Bytes = utf8Bytes, + Chars = new char[1024], + ExpectedException = ExpectedException.ArgumentOutOfRangeException("IDX10716:"), + ExpectedValue = expectedValue, + Json = json, + Length = -1, + Offset = 5, + OffsetBytes = utf8BytesOffset, + OffsetLength = -1, + EncodingOnly = true + }); + + theoryData.Add(new Base64UrlEncoderTheoryData("Length_Zero_Encoding") + { + Bytes = new byte[0], + Chars = new char[1024], + ExpectedValue = string.Empty, + Json = string.Empty, + Length = 0, + Offset = 0, + OffsetBytes = new byte[0], + OffsetLength = 0 + }); + + theoryData.Add(new Base64UrlEncoderTheoryData("Bytes_Zero_Encoding") + { + Bytes = new byte[0], + Chars = new char[1024], + ExpectedValue = string.Empty, + Json = string.Empty, + Length = 0, + Offset = 0, + OffsetBytes = new byte[0], + OffsetLength = 0, + EncodingOnly = true + }); + + + return theoryData; + } + } + + public class Base64UrlEncoderTheoryData : TheoryDataBase + { + public Base64UrlEncoderTheoryData(string testId) : base(testId) { } + + public byte[] Bytes { get; set; } + + public char[] Chars { get; set; } + + public string ExpectedValue { get; set; } + + public string Json { get; set; } + + public int Length { get; set; } + + public int Offset { get; set; } + + public byte[] OffsetBytes { get; set; } + + public int OffsetLength { get; set; } + + public bool EncodingOnly { get; set; } = false; + } + } +} diff --git a/test/Microsoft.IdentityModel.Tokens.Tests/CryptoProviderFactoryTests.cs b/test/Microsoft.IdentityModel.Tokens.Tests/CryptoProviderFactoryTests.cs index b43d41388c..ec42035a02 100644 --- a/test/Microsoft.IdentityModel.Tokens.Tests/CryptoProviderFactoryTests.cs +++ b/test/Microsoft.IdentityModel.Tokens.Tests/CryptoProviderFactoryTests.cs @@ -315,7 +315,7 @@ public void FaultingSymmetricSignatureProviders(SignatureProviderTheoryData theo var signingSignatureProvider = theoryData.CryptoProviderFactory.CreateForSigning(theoryData.SigningKey, theoryData.SigningAlgorithm) as SymmetricSignatureProvider; var signedBytes = signingSignatureProvider.Sign(bytes); var verifyingSignatureProvider = theoryData.CryptoProviderFactory.CreateForVerifying(theoryData.VerifyKey, theoryData.VerifyAlgorithm) as SymmetricSignatureProvider; - if (theoryData.VerifySpecifyingLength) + if (theoryData.VerifyUsingLength) verifyingSignatureProvider.Verify(bytes, signedBytes); else verifyingSignatureProvider.Verify(bytes, signedBytes, signedBytes.Length); @@ -426,7 +426,7 @@ public static TheoryData FaultingSymmetricSignature VerifyAlgorithm = ALG.HmacSha256, VerifyKey = Default.SymmetricSigningKey256, VerifySignatureProviderType = typeof(CustomSymmetricSignatureProvider).ToString(), - VerifySpecifyingLength = true + VerifyUsingLength = true }); // Symmetric disposed signing @@ -496,7 +496,7 @@ public static TheoryData FaultingSymmetricSignature VerifyAlgorithm = ALG.HmacSha256, VerifyKey = Default.SymmetricSigningKey256, VerifySignatureProviderType = typeof(CustomSymmetricSignatureProvider).ToString(), - VerifySpecifyingLength = true + VerifyUsingLength = true }); // Symmetric signing verifying succeed @@ -518,7 +518,7 @@ public static TheoryData FaultingSymmetricSignature VerifyAlgorithm = ALG.HmacSha256, VerifyKey = Default.SymmetricSigningKey256, VerifySignatureProviderType = typeof(CustomSymmetricSignatureProvider).ToString(), - VerifySpecifyingLength = false + VerifyUsingLength = false }); // Symmetric signing verifying (specifying length) succeed @@ -540,7 +540,7 @@ public static TheoryData FaultingSymmetricSignature VerifyAlgorithm = ALG.HmacSha256, VerifyKey = Default.SymmetricSigningKey256, VerifySignatureProviderType = typeof(CustomSymmetricSignatureProvider).ToString(), - VerifySpecifyingLength = true + VerifyUsingLength = true }); return theoryData; diff --git a/test/Microsoft.IdentityModel.Tokens.Tests/IdentityComparerTests.cs b/test/Microsoft.IdentityModel.Tokens.Tests/IdentityComparerTests.cs index 63c79da8b6..e78a0391b6 100644 --- a/test/Microsoft.IdentityModel.Tokens.Tests/IdentityComparerTests.cs +++ b/test/Microsoft.IdentityModel.Tokens.Tests/IdentityComparerTests.cs @@ -577,7 +577,7 @@ public void CompareStrings() Assert.True(context.Diffs.Count(s => s == "str1 != str2, StringComparison: 'Ordinal'") == 1); Assert.True(context.Diffs[1] == string1); - Assert.True(context.Diffs[2] == string2); + Assert.True(context.Diffs[3] == string2); } [Fact] diff --git a/test/Microsoft.IdentityModel.Tokens.Tests/SignatureProviderTests.cs b/test/Microsoft.IdentityModel.Tokens.Tests/SignatureProviderTests.cs index aea69ace06..d415d2120e 100644 --- a/test/Microsoft.IdentityModel.Tokens.Tests/SignatureProviderTests.cs +++ b/test/Microsoft.IdentityModel.Tokens.Tests/SignatureProviderTests.cs @@ -3,11 +3,10 @@ using System; using System.Collections.Generic; -using System.Reflection; using System.Runtime.InteropServices; using System.Security.Cryptography; using System.Text; -using System.Threading.Tasks; +using Microsoft.Azure.KeyVault.Cryptography; using Microsoft.IdentityModel.TestUtils; using Xunit; @@ -15,8 +14,6 @@ using EE = Microsoft.IdentityModel.TestUtils.ExpectedException; using KEY = Microsoft.IdentityModel.TestUtils.KeyingMaterial; -#pragma warning disable CS3016 // Arrays as attribute arguments is not CLS-compliant - namespace Microsoft.IdentityModel.Tokens.Tests { /// @@ -1039,6 +1036,367 @@ public static TheoryData SignatureTheoryData() return theoryData; } + + /// + /// Tests that the signature size returned from TokenUtilities.GetSignatureSize(string algorithm) ia not too small. + /// Each supported signature is tried, 2k is the default. + /// + /// + [Theory, MemberData(nameof(MaximumSignatureSizeTestCases), DisableDiscoveryEnumeration = true)] + public void MaximumSignatureSizeTests(SignTheoryData theoryData) + { + var context = TestUtilities.WriteHeader($"MaximumSignatureSizeTests", theoryData); + + try + { + byte[] signature = theoryData.SignatureProvider.Sign(theoryData.Bytes); + int maximumSignatureSize = SupportedAlgorithms.GetMaxByteCount(theoryData.SignatureProvider.Algorithm); + if (signature.Length > maximumSignatureSize) + context.AddDiff($"signature.Length: '{signature.Length}' > maximumSignatureSize: '{maximumSignatureSize}'."); + + theoryData.ExpectedException.ProcessNoException(context); + } + catch (Exception ex) + { + theoryData.ExpectedException.ProcessException(ex, context); + } + + TestUtilities.AssertFailIfErrors(context); + } + + public static TheoryData MaximumSignatureSizeTestCases + { + get + { + var theoryData = new TheoryData(); + + AddSymmetricKeySizes(KeyingMaterial.DefaultSymmetricSecurityKey_256, theoryData); + AddSymmetricKeySizes(KeyingMaterial.DefaultSymmetricSecurityKey_384, theoryData); + AddSymmetricKeySizes(KeyingMaterial.DefaultSymmetricSecurityKey_512, theoryData); + + AddECDSAKeySizes(KeyingMaterial.Ecdsa256Key, theoryData); + AddECDSAKeySizes(KeyingMaterial.Ecdsa384Key, theoryData); + AddECDSAKeySizes(KeyingMaterial.Ecdsa521Key, theoryData); + + AddRSAKeySize(KeyingMaterial.RsaSecurityKey_1024, theoryData); + AddRSAKeySize(KeyingMaterial.RsaSecurityKey_2048, theoryData); + AddRSAKeySize(KeyingMaterial.RsaSecurityKey_4096, theoryData); + + theoryData.Add(new SignTheoryData("Custom2K") + { + SignatureProvider = new SignatureProvider2K(KeyingMaterial.RsaSecurityKey_2048, "CustomAlgorithm") + }); + + return theoryData; + } + } + private static void AddECDSAKeySizes(SecurityKey securityKey, TheoryData theoryData) + { + byte[] bytes = Encoding.UTF8.GetBytes(Guid.NewGuid().ToString()); + + foreach (string algorithm in SupportedAlgorithms.EcdsaSigningAlgorithms) + { + if (securityKey.KeySize >= AsymmetricSignatureProvider.DefaultMinimumAsymmetricKeySizeInBitsForSigningMap[algorithm]) + theoryData.Add(new SignTheoryData($"{algorithm}_Key{securityKey.KeySize}") + { + Bytes = bytes, + SignatureProvider = CreateProvider(securityKey, algorithm) + }); + } + } + + private static void AddSymmetricKeySizes(SecurityKey securityKey, TheoryData theoryData) + { + byte[] bytes = Encoding.UTF8.GetBytes(Guid.NewGuid().ToString()); + + foreach (string algorithm in SupportedAlgorithms.SymmetricSigningAlgorithms) + { + if (securityKey.KeySize / 8 >= SymmetricSignatureProvider.ExpectedSignatureSizeInBytes[algorithm]) + theoryData.Add(new SignTheoryData($"{algorithm}_Key{securityKey.KeySize}") + { + Bytes = bytes, + SignatureProvider = CreateProvider(securityKey, algorithm) + }); + } + } + + private static void AddRSAKeySize(SecurityKey securityKey, TheoryData theoryData) + { + byte[] bytes = Encoding.UTF8.GetBytes(Guid.NewGuid().ToString()); + + foreach (string algorithm in SupportedAlgorithms.RsaSigningAlgorithms) + { + if (securityKey.KeySize >= AsymmetricSignatureProvider.DefaultMinimumAsymmetricKeySizeInBitsForSigningMap[algorithm]) + theoryData.Add(new SignTheoryData($"{algorithm}_Key{securityKey.KeySize}") + { + Bytes = bytes, + SignatureProvider = CreateProvider(securityKey, algorithm) + }); + } + + foreach (string algorithm in SupportedAlgorithms.RsaPssSigningAlgorithms) + { + if (securityKey.KeySize >= AsymmetricSignatureProvider.DefaultMinimumAsymmetricKeySizeInBitsForSigningMap[algorithm]) + theoryData.Add(new SignTheoryData($"{algorithm}_Key{securityKey.KeySize}") + { + Bytes = bytes, + SignatureProvider = CreateProvider(securityKey, algorithm) + }); + } + } + +#if NET6_0_OR_GREATER + [Theory, MemberData(nameof(SignUsingSpanTestCases), DisableDiscoveryEnumeration = true)] + public void SignUsingSpanTests(SignTheoryData theoryData) + { + var context = TestUtilities.WriteHeader("SignUsingSpanTests", theoryData); + + try + { + bool success = theoryData.SignatureProvider.Sign(theoryData.Bytes.AsSpan(), theoryData.Buffer.AsSpan(), out int bytesWritten); + + IdentityComparer.AreBoolsEqual(success, theoryData.Success, context); + if (theoryData.Success) + IdentityComparer.AreBoolsEqual(theoryData.SignatureProvider.Verify(theoryData.Bytes, theoryData.Buffer.AsSpan().Slice(0, bytesWritten).ToArray()), true, $"{theoryData.SignatureProvider}", "true", context); + + theoryData.ExpectedException.ProcessNoException(context); + } + catch (Exception ex) + { + theoryData.ExpectedException.ProcessException(ex, context); + } + + TestUtilities.AssertFailIfErrors(context); + } + + public static TheoryData SignUsingSpanTestCases + { + get + { + TheoryData theoryData = new TheoryData(); + byte[] bytes = Encoding.UTF8.GetBytes(Guid.NewGuid().ToString()); + + AddSignUsingSpans(bytes, KeyingMaterial.Ecdsa256Key, SecurityAlgorithms.EcdsaSha256, "ECDSA", theoryData); + AddSignUsingSpans(bytes, KeyingMaterial.RsaSecurityKey_2048, SecurityAlgorithms.RsaSha256, "RSA", theoryData); + AddSignUsingSpans(bytes, new SymmetricSecurityKey(KeyingMaterial.SymmetricKeyBytes2_256), SecurityAlgorithms.HmacSha256, "HMAC256", theoryData); + + theoryData.Add(new SignTheoryData("NotImplementedException") + { + Buffer = new byte[2048], + Bytes = new byte[2048], + Count = 2048, + ExpectedException = new ExpectedException(typeof(NotImplementedException)), + Offset = 0, + SignatureProvider = new SignatureProvider2K(KeyingMaterial.Ecdsa256Key, SecurityAlgorithms.EcdsaSha256) + }); + + return theoryData; + } + } + + internal static void AddSignUsingSpans(byte[] bytes, SecurityKey securityKey, string algorithm, string prefix, TheoryData theoryData) + { + theoryData.Add(new SignTheoryData($"{prefix}_BufferNull") + { + Buffer = null, + Bytes = bytes, + SignatureProvider = CreateProvider(securityKey, algorithm), + Success = false + }); + + theoryData.Add(new SignTheoryData($"{prefix}_BufferOneByte") + { + Buffer = new byte[1], + Bytes = bytes, + SignatureProvider = CreateProvider(securityKey, algorithm), + Success = false + }); + + theoryData.Add(new SignTheoryData($"{prefix}_BufferTooSmall") + { + Buffer = new byte[10], + Bytes = bytes, + SignatureProvider = CreateProvider(securityKey, algorithm), + Success = false + }); + + theoryData.Add(new SignTheoryData($"{prefix}") + { + Buffer = new byte[512], + Bytes = bytes, + SignatureProvider = CreateProvider(securityKey, algorithm), + Success = true + }); + } +#endif + [Theory, MemberData(nameof(SignUsingOffsetTestCases), DisableDiscoveryEnumeration = true)] + public void SignUsingOffsetTests(SignTheoryData theoryData) + { + var context = TestUtilities.WriteHeader("SignUsingOffsetTests", theoryData); + try + { + byte[] signature = theoryData.SignatureProvider.Sign(theoryData.Bytes, theoryData.Offset, theoryData.Count); + if (theoryData.Success) + IdentityComparer.AreBoolsEqual( + theoryData.SignatureProvider.Verify( + theoryData.Bytes.AsSpan().Slice(theoryData.Offset, theoryData.Count).ToArray(), + signature), + true, + $"{theoryData.SignatureProvider}", + "true", + context); + + theoryData.ExpectedException.ProcessNoException(context); + } + catch (Exception ex) + { + theoryData.ExpectedException.ProcessException(ex, context); + } + + TestUtilities.AssertFailIfErrors(context); + } + + public static TheoryData SignUsingOffsetTestCases + { + get + { + TheoryData theoryData = new TheoryData(); + + byte[] bytes = Encoding.UTF8.GetBytes(Guid.NewGuid().ToString()); + AddSignUsingOffsets(bytes, KeyingMaterial.Ecdsa256Key, SecurityAlgorithms.EcdsaSha256, "ECDSA", theoryData); + AddSignUsingOffsets(bytes, KeyingMaterial.RsaSecurityKey_2048, SecurityAlgorithms.RsaSha256, "RSA", theoryData); + AddSignUsingOffsets(bytes, new SymmetricSecurityKey(KeyingMaterial.SymmetricKeyBytes2_256), SecurityAlgorithms.HmacSha256, "HMAC256", theoryData); + + theoryData.Add(new SignTheoryData("NotImplementedException") + { + Bytes = new byte[1024], + Count = 1024, + ExpectedException = new ExpectedException(typeof(NotImplementedException)), + Offset = 0, + SignatureProvider = new SignatureProvider2K(KeyingMaterial.Ecdsa256Key, SecurityAlgorithms.EcdsaSha256) + }); + + return theoryData; + } + } + + internal static void AddSignUsingOffsets(byte[] bytes, SecurityKey securityKey, string algorithm, string prefix, TheoryData theoryData) + { + theoryData.Add(new SignTheoryData($"{prefix}_BytesNull") + { + Bytes = null, + Count = bytes.Length, + ExpectedException = ExpectedException.ArgumentNullException(), + Offset = 0, + SignatureProvider = CreateProvider(securityKey, algorithm) + }); + + theoryData.Add(new SignTheoryData($"{prefix}_BytesEmpty") + { + Bytes = Array.Empty(), + Count = bytes.Length, + ExpectedException = ExpectedException.ArgumentNullException(), + Offset = 0, + SignatureProvider = CreateProvider(securityKey, algorithm) + }); + +#if NET461 || NET462 + // RSA throws a different exception in the following three cases than HMAC or ECDSA 472+ + theoryData.Add(new SignTheoryData($"{prefix}_CountNegative") + { + Bytes = bytes, + Count = -1, + ExpectedException = ExpectedException.ArgumentException(), + Offset = 0, + SignatureProvider = CreateProvider(securityKey, algorithm) + }); + + theoryData.Add(new SignTheoryData($"{prefix}_CountGreaterThanBytes") + { + Bytes = bytes, + Count = bytes.Length + 1, + ExpectedException = ExpectedException.ArgumentException(), + Offset = 0, + SignatureProvider = CreateProvider(securityKey, algorithm) + }); + + theoryData.Add(new SignTheoryData($"{prefix}_CountPlusOffsetGreaterThanBytes") + { + Bytes = bytes, + Count = 10, + ExpectedException = ExpectedException.ArgumentException(), + Offset = bytes.Length - 1, + SignatureProvider = CreateProvider(securityKey, algorithm) + }); +#else + // RSA throws a different exception in the following three cases than HMAC or ECDSA 472+ + theoryData.Add(new SignTheoryData($"{prefix}_CountNegative") + { + Bytes = bytes, + Count = -1, + ExpectedException = prefix == "RSA" ? ExpectedException.ArgumentOutOfRangeException() : ExpectedException.ArgumentException(), + Offset = 0, + SignatureProvider = CreateProvider(securityKey, algorithm) + }); + + theoryData.Add(new SignTheoryData($"{prefix}_CountGreaterThanBytes") + { + Bytes = bytes, + Count = bytes.Length + 1, + ExpectedException = prefix == "RSA" ? ExpectedException.ArgumentOutOfRangeException() : ExpectedException.ArgumentException(), + Offset = 0, + SignatureProvider = CreateProvider(securityKey, algorithm) + }); + + theoryData.Add(new SignTheoryData($"{prefix}_CountPlusOffsetGreaterThanBytes") + { + Bytes = bytes, + Count = 10, + ExpectedException = prefix == "RSA" ? ExpectedException.ArgumentOutOfRangeException() : ExpectedException.ArgumentException(), + Offset = bytes.Length - 1, + SignatureProvider = CreateProvider(securityKey, algorithm) + }); +#endif + theoryData.Add(new SignTheoryData($"{prefix}_OffsetNegative") + { + Bytes = bytes, + Count = bytes.Length, + ExpectedException = ExpectedException.ArgumentOutOfRangeException(), + Offset = -1, + SignatureProvider = CreateProvider(securityKey, algorithm) + }); + + theoryData.Add(new SignTheoryData($"{prefix}") + { + Bytes = bytes, + Count = bytes.Length, + Offset = 0, + SignatureProvider = CreateProvider(securityKey, algorithm), + Success = true + }); + + byte[] bytesOffset = new byte[bytes.Length + 10]; + Array.Copy(bytes, 0, bytesOffset, 5, bytes.Length); + theoryData.Add(new SignTheoryData($"{prefix}_Offset") + { + Bytes = bytesOffset, + Count = bytes.Length, + Offset = 5, + SignatureProvider = CreateProvider(securityKey, algorithm), + Success = true + }); + } + + public static SignatureProvider CreateProvider(SecurityKey securityKey, string algorithm) + { + if (securityKey is AsymmetricSecurityKey) + return new AsymmetricSignatureProvider(securityKey, algorithm); + + if (securityKey is SymmetricSecurityKey) + return new SymmetricSignatureProvider(securityKey, algorithm); + + throw new NotSupportedException($"Unknown securityKey type: '{securityKey}'"); + } } public class CryptoProviderFactoryTheoryData : TheoryDataBase, IDisposable @@ -1127,7 +1485,7 @@ public SignatureProviderTheoryData(string testId, string signingAlgorithm, strin public string SignatureProviderType { get; set; } - public bool VerifySpecifyingLength { get; set; } + public bool VerifyUsingLength { get; set; } } public class SymmetricSignatureProviderTheoryData : TheoryDataBase @@ -1138,6 +1496,44 @@ public SymmetricSignatureProviderTheoryData(string testId) : base(testId) { } public SecurityKey SecurityKey { get; set; } } -} -#pragma warning restore CS3016 // Arrays as attribute arguments is not CLS-compliant + public class SignTheoryData : TheoryDataBase + { + public SignTheoryData() { } + + public SignTheoryData(string testId) : base(testId) { } + + public string Algorithm { get; set; } + + public byte[] Buffer { get; set; } + + public byte[] Bytes { get; set; } + + public int Count { get; set; } + + public string HashAlgorithmString { get; set; } + + public int Offset { get; set; } = 0; + + public SecurityKey SecurityKey { get; set; } + + public byte[] Signature { get; set; } + + public SignatureProvider SignatureProvider { get; set; } + + public bool Success { get; set; } + } + + public class SignatureProvider2K : SignatureProvider + { + public SignatureProvider2K(SecurityKey key, string algorithm):base(key, algorithm){} + + public override byte[] Sign(byte[] input) => new byte[2048]; + + public override bool Verify(byte[] input, byte[] signature) => throw new NotImplementedException(); + + protected override void Dispose(bool disposing) => throw new NotImplementedException(); + + public override bool Verify(byte[] input, int inputOffset, int inputLength, byte[] signature, int signatureOffset, int signatureLength) => throw new NotImplementedException(); + } +}