diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..65bae17 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,7 @@ +[*.cs] + +# IDE0008: Use explicit type +dotnet_diagnostic.IDE0008.severity = silent + +# IDE0300: Simplify collection initialization +dotnet_diagnostic.IDE0300.severity = silent diff --git a/TechnitiumLibrary.IO/BinaryReaderExtensions.cs b/TechnitiumLibrary.IO/BinaryReaderExtensions.cs index d2ec8cd..f2db442 100644 --- a/TechnitiumLibrary.IO/BinaryReaderExtensions.cs +++ b/TechnitiumLibrary.IO/BinaryReaderExtensions.cs @@ -28,7 +28,14 @@ public static class BinaryReaderExtensions { public static byte[] ReadBuffer(this BinaryReader bR) { - return bR.ReadBytes(ReadLength(bR)); + int len = ReadLength(bR); + + byte[] buffer = bR.ReadBytes(len); + + if (buffer.Length != len) + throw new EndOfStreamException("Unexpected end of stream while reading buffer."); + + return buffer; } public static string ReadShortString(this BinaryReader bR) @@ -38,32 +45,53 @@ public static string ReadShortString(this BinaryReader bR) public static string ReadShortString(this BinaryReader bR, Encoding encoding) { - return encoding.GetString(bR.ReadBytes(bR.ReadByte())); + int length = bR.ReadByte(); + byte[] bytes = bR.ReadBytes(length); + + if (bytes.Length != length) + throw new EndOfStreamException("Not enough bytes to read short string."); + + return encoding.GetString(bytes); } public static DateTime ReadDateTime(this BinaryReader bR) { - return DateTime.UnixEpoch.AddMilliseconds(bR.ReadInt64()); + // Read int64 big-endian timestamp (same as original behavior because .NET native is LE) + Span buffer = stackalloc byte[8]; + int read = bR.BaseStream.Read(buffer); + + if (read != 8) + throw new EndOfStreamException("Not enough bytes to read DateTime ticks."); + + long millis = BinaryPrimitives.ReadInt64LittleEndian(buffer); + return DateTime.UnixEpoch.AddMilliseconds(millis); } public static int ReadLength(this BinaryReader bR) { - int length1 = bR.ReadByte(); - if (length1 > 127) - { - int numberLenBytes = length1 & 0x7F; - if (numberLenBytes > 4) - throw new IOException("BinaryReaderExtension encoding length not supported."); - - Span valueBytes = stackalloc byte[4]; - bR.BaseStream.ReadExactly(valueBytes.Slice(4 - numberLenBytes, numberLenBytes)); - - return BinaryPrimitives.ReadInt32BigEndian(valueBytes); - } - else - { - return length1; - } + int first = bR.ReadByte(); + if (first < 0) + throw new EndOfStreamException("Not enough bytes for a length prefix."); + + // Single byte value + if (first <= 127) + return first; + + // Otherwise, multi-byte length + int numberLenBytes = first & 0x7F; + + if (numberLenBytes > 4) + throw new IOException("BinaryReaderExtension encoding length not supported."); + + Span temp = stackalloc byte[4]; + + int offset = 4 - numberLenBytes; + int readBytes = bR.BaseStream.Read(temp[offset..]); + + if (readBytes != numberLenBytes) + throw new EndOfStreamException("Not enough bytes for encoded length."); + + return BinaryPrimitives.ReadInt32BigEndian(temp); } } -} +} \ No newline at end of file diff --git a/TechnitiumLibrary.IO/Joint.cs b/TechnitiumLibrary.IO/Joint.cs index 703776a..092edb7 100644 --- a/TechnitiumLibrary.IO/Joint.cs +++ b/TechnitiumLibrary.IO/Joint.cs @@ -1,24 +1,6 @@ -/* -Technitium Library -Copyright (C) 2024 Shreyas Zare (shreyas@technitium.com) - -This program is free software: you can redistribute it and/or modify -it under the terms of the GNU General Public License as published by -the Free Software Foundation, either version 3 of the License, or -(at your option) any later version. - -This program is distributed in the hope that it will be useful, -but WITHOUT ANY WARRANTY; without even the implied warranty of -MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -GNU General Public License for more details. - -You should have received a copy of the GNU General Public License -along with this program. If not, see . - -*/ - -using System; +using System; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace TechnitiumLibrary.IO @@ -36,6 +18,9 @@ public class Joint : IDisposable readonly Stream _stream1; readonly Stream _stream2; + // track copy completion + private int _pendingCopies = 2; + #endregion #region constructor @@ -72,11 +57,8 @@ protected virtual void Dispose(bool disposing) { Disposing?.Invoke(this, EventArgs.Empty); - if (_stream1 != null) - _stream1.Dispose(); - - if (_stream2 != null) - _stream2.Dispose(); + _stream1?.Dispose(); + _stream2?.Dispose(); } } } @@ -85,6 +67,12 @@ protected virtual void Dispose(bool disposing) #region private + private void OnCopyFinished() + { + if (Interlocked.Decrement(ref _pendingCopies) == 0) + Dispose(); + } + private async Task CopyToAsync(Stream src, Stream dst) { try @@ -93,7 +81,7 @@ private async Task CopyToAsync(Stream src, Stream dst) } finally { - Dispose(); + OnCopyFinished(); } } @@ -111,11 +99,9 @@ public void Start() #region properties - public Stream Stream1 - { get { return _stream1; } } + public Stream Stream1 => _stream1; - public Stream Stream2 - { get { return _stream2; } } + public Stream Stream2 => _stream2; #endregion } diff --git a/TechnitiumLibrary.IO/PackageItem.cs b/TechnitiumLibrary.IO/PackageItem.cs index 534698a..c721dbe 100644 --- a/TechnitiumLibrary.IO/PackageItem.cs +++ b/TechnitiumLibrary.IO/PackageItem.cs @@ -154,9 +154,14 @@ public static PackageItem Parse(Stream s) item._extractToCustomLocation = Encoding.UTF8.GetString(bR.ReadBytes(bR.ReadByte())); long length = bR.ReadInt64(); - item._data = new OffsetStream(bR.BaseStream, bR.BaseStream.Position, length, true); - bR.BaseStream.Position += length; + long startOffset = bR.BaseStream.Position; + + // Create slice before advancing stream pointer + item._data = new OffsetStream(bR.BaseStream, startOffset, length, readOnly: true); + + // Seek explicitly + bR.BaseStream.Seek(length, SeekOrigin.Current); return item; diff --git a/TechnitiumLibrary.Net/DomainEndPoint.cs b/TechnitiumLibrary.Net/DomainEndPoint.cs index b911a58..08ae445 100644 --- a/TechnitiumLibrary.Net/DomainEndPoint.cs +++ b/TechnitiumLibrary.Net/DomainEndPoint.cs @@ -61,6 +61,12 @@ private DomainEndPoint() public static bool TryParse(string value, out DomainEndPoint ep) { + if (string.IsNullOrEmpty(value) || string.IsNullOrWhiteSpace(value)) + { + ep = null; + return false; + } + string[] parts = value.Split(':'); if (parts.Length > 2) { diff --git a/TechnitiumLibrary.Net/EndPointExtensions.cs b/TechnitiumLibrary.Net/EndPointExtensions.cs index 1e5923a..7b381a7 100644 --- a/TechnitiumLibrary.Net/EndPointExtensions.cs +++ b/TechnitiumLibrary.Net/EndPointExtensions.cs @@ -181,20 +181,33 @@ public static EndPoint GetEndPoint(string address, int port) public static bool TryParse(string value, out EndPoint ep) { + ep = null; + if (string.IsNullOrWhiteSpace(value)) + return false; + + // First handle IP:port if (IPEndPoint.TryParse(value, out IPEndPoint ep1)) { ep = ep1; return true; } - if (DomainEndPoint.TryParse(value, out DomainEndPoint ep2)) - { - ep = ep2; - return true; - } + // Now handle domain:port + int idx = value.LastIndexOf(':'); + if (idx <= 0) // must be >0 because first char cannot be colon + return false; - ep = null; - return false; + string host = value.Substring(0, idx); + string portText = value.Substring(idx + 1); + + if (!int.TryParse(portText, out int port) || port < 0 || port > 65535) + return false; + + if (!DomainEndPoint.TryParse(value, out DomainEndPoint ep2)) + return false; + + ep = ep2; + return true; } public static bool IsEquals(this EndPoint ep, EndPoint other) @@ -208,18 +221,12 @@ public static bool IsEquals(this EndPoint ep, EndPoint other) if (ep.AddressFamily != other.AddressFamily) return false; - switch (ep.AddressFamily) + return ep.AddressFamily switch { - case AddressFamily.InterNetwork: - case AddressFamily.InterNetworkV6: - return (ep as IPEndPoint).Equals(other); - - case AddressFamily.Unspecified: - return (ep as DomainEndPoint).Equals(other); - - default: - throw new NotSupportedException("Address Family not supported."); - } + AddressFamily.InterNetwork or AddressFamily.InterNetworkV6 => (ep as IPEndPoint).Equals(other), + AddressFamily.Unspecified => (ep as DomainEndPoint).Equals(other), + _ => throw new NotSupportedException("Address Family not supported."), + }; } #endregion diff --git a/TechnitiumLibrary.Net/IPAddressExtensions.cs b/TechnitiumLibrary.Net/IPAddressExtensions.cs index 3852452..6475051 100644 --- a/TechnitiumLibrary.Net/IPAddressExtensions.cs +++ b/TechnitiumLibrary.Net/IPAddressExtensions.cs @@ -142,61 +142,57 @@ public static IPAddress GetSubnetMask(int prefixLength) return new IPAddress(subnetMaskBuffer); } - public static IPAddress GetNetworkAddress(this IPAddress address, int prefixLength) - { - switch (address.AddressFamily) - { - case AddressFamily.InterNetwork: - { - if (prefixLength == 32) - return address; - if (prefixLength > 32) - throw new ArgumentOutOfRangeException(nameof(prefixLength), "Invalid network prefix."); + private static IPAddress MaskAddress(ReadOnlySpan addressBytes, int prefixLength) + { + Span output = stackalloc byte[addressBytes.Length]; + output.Clear(); // IMPORTANT: zero out host part by default - Span addressBytes = stackalloc byte[4]; - if (!address.TryWriteBytes(addressBytes, out _)) - throw new InvalidOperationException(); + int fullBytes = prefixLength / 8; + int remainderBits = prefixLength % 8; - Span networkAddress = stackalloc byte[4]; - int copyBytes = prefixLength / 8; - int balanceBits = prefixLength - (copyBytes * 8); + if (fullBytes > 0) + addressBytes[..fullBytes].CopyTo(output); - addressBytes.Slice(0, copyBytes).CopyTo(networkAddress); + if (remainderBits > 0) + { + // Mask the next byte, keeping only the top 'remainderBits' + byte mask = (byte)(0xFF << (8 - remainderBits)); + output[fullBytes] = (byte)(addressBytes[fullBytes] & mask); + } - if (balanceBits > 0) - networkAddress[copyBytes] = (byte)(addressBytes[copyBytes] & (0xFF << (8 - balanceBits))); + return new IPAddress(output); + } + public static IPAddress GetNetworkAddress(this IPAddress address, int prefixLength) + { + if (address is null) + throw new ArgumentNullException(nameof(address)); + if (prefixLength < 0) + throw new ArgumentOutOfRangeException(nameof(prefixLength), "Prefix length cannot be negative."); - return new IPAddress(networkAddress); - } + int maxBits, byteCount; + switch (address.AddressFamily) + { + case AddressFamily.InterNetwork: + maxBits = 32; byteCount = 4; break; case AddressFamily.InterNetworkV6: - { - if (prefixLength == 128) - return address; - - if (prefixLength > 128) - throw new ArgumentOutOfRangeException(nameof(prefixLength), "Invalid network prefix."); - - Span addressBytes = stackalloc byte[16]; - if (!address.TryWriteBytes(addressBytes, out _)) - throw new InvalidOperationException(); + maxBits = 128; byteCount = 16; break; + default: + throw new NotSupportedException("Address Family not supported."); + } - Span networkAddress = stackalloc byte[16]; - int copyBytes = prefixLength / 8; - int balanceBits = prefixLength - (copyBytes * 8); + if (prefixLength == maxBits) + return address; - addressBytes.Slice(0, copyBytes).CopyTo(networkAddress); + if (prefixLength > maxBits) + throw new ArgumentOutOfRangeException(nameof(prefixLength), "Invalid network prefix."); - if (balanceBits > 0) - networkAddress[copyBytes] = (byte)(addressBytes[copyBytes] & (0xFF << (8 - balanceBits))); + Span bytes = stackalloc byte[byteCount]; + if (!address.TryWriteBytes(bytes, out _)) + throw new InvalidOperationException("Failed to serialize IP address bytes."); - return new IPAddress(networkAddress); - } - - default: - throw new NotSupportedException("Address Family not supported."); - } + return MaskAddress(bytes, prefixLength); } public static IPAddress MapToIPv6(this IPAddress address, NetworkAddress ipv6Prefix) @@ -440,15 +436,24 @@ public static bool TryParseReverseDomain(string ptrDomain, out IPAddress address { if (ptrDomain.EndsWith(".in-addr.arpa", StringComparison.OrdinalIgnoreCase)) { - //1.10.168.192.in-addr.arpa - //192.168.10.1 + string[] segments = ptrDomain.Split('.'); + + // Expected form: A.B.C.D.in-addr.arpa + // → exactly 7 segments + if (segments.Length != 6) + { + address = null; + return false; + } - string[] parts = ptrDomain.Split('.'); Span buffer = stackalloc byte[4]; - for (int i = 0, j = parts.Length - 3; (i < 4) && (j > -1); i++, j--) + // Extract forward as standard IPv4 order + // PTR: A.B.C.D.in-addr.arpa + // IP: D.C.B.A + for (int i = 0; i < 4; i++) { - if (!byte.TryParse(parts[j], out buffer[i])) + if (!byte.TryParse(segments[3 - i], out buffer[i])) { address = null; return false; diff --git a/TechnitiumLibrary.Net/NetUtilities.cs b/TechnitiumLibrary.Net/NetUtilities.cs index a82231c..a64f143 100644 --- a/TechnitiumLibrary.Net/NetUtilities.cs +++ b/TechnitiumLibrary.Net/NetUtilities.cs @@ -31,6 +31,9 @@ public static class NetUtilities public static bool IsPrivateIP(IPAddress address) { + if (address is null) + throw new ArgumentNullException(nameof(address)); + if (address.IsIPv4MappedToIPv6) address = address.MapToIPv4(); diff --git a/TechnitiumLibrary.Net/NetworkMap.cs b/TechnitiumLibrary.Net/NetworkMap.cs index 8c12473..da6c3d5 100644 --- a/TechnitiumLibrary.Net/NetworkMap.cs +++ b/TechnitiumLibrary.Net/NetworkMap.cs @@ -150,6 +150,14 @@ public bool TryGetValue(IPAddress address, out T value) IpEntry findEntry = new IpEntry(address); + // NEW: must short-circuit mismatched families + if (_ipLookupList.Count > 0 && + _ipLookupList[0].IpAddress.Value.Length != findEntry.IpAddress.Value.Length) + { + value = default; + return false; + } + IpEntry floorEntry = GetFloorEntry(findEntry); IpEntry ceilingEntry = GetCeilingEntry(findEntry); diff --git a/TechnitiumLibrary.Security.OTP/Authenticator.cs b/TechnitiumLibrary.Security.OTP/Authenticator.cs index a4bdf42..2736c3e 100644 --- a/TechnitiumLibrary.Security.OTP/Authenticator.cs +++ b/TechnitiumLibrary.Security.OTP/Authenticator.cs @@ -33,80 +33,85 @@ public class Authenticator { #region variables - readonly AuthenticatorKeyUri _keyUri; readonly byte[] _key; #endregion #region constructor + public Authenticator(AuthenticatorKeyUri keyUri) { if (!keyUri.Type.Equals("totp", StringComparison.OrdinalIgnoreCase)) - throw new NotSupportedException($"The authenticator key URI type '{_keyUri.Type}' is not supported."); + throw new NotSupportedException($"The authenticator key URI type '{keyUri.Type}' is not supported."); + + KeyUri = keyUri; + _key = Base32.FromBase32String(KeyUri.Secret); - _keyUri = keyUri; - _key = Base32.FromBase32String(_keyUri.Secret); + // Optional: validate digits per RFC common practice + if (KeyUri.Digits < 6 || KeyUri.Digits > 8) + throw new ArgumentOutOfRangeException(nameof(keyUri), "Digits should be 6–8 per common TOTP deployments."); } #endregion #region private + + private static bool ConstantTimeEquals(string a, string b) + { + if (a.Length != b.Length) return false; + int diff = 0; + for (int i = 0; i < a.Length; i++) + diff |= a[i] ^ b[i]; + return diff == 0; + } + private static string HOTP(byte[] k, long c, int digits = 6, string algorithm = "SHA1") { - HMAC hmac = null; - try + HMAC hmac = algorithm.ToUpperInvariant() switch { - int outLength; - - switch (algorithm.ToUpperInvariant()) - { - case "SHA1": - hmac = new HMACSHA1(k); - outLength = SHA1.HashSizeInBytes; - break; - - case "SHA256": - hmac = new HMACSHA256(k); - outLength = SHA256.HashSizeInBytes; - break; - - case "SHA512": - hmac = new HMACSHA512(k); - outLength = SHA512.HashSizeInBytes; - break; - - default: - throw new NotSupportedException("Hash algorithm is not supported: " + algorithm); - } + "SHA1" => new HMACSHA1(k), + "SHA256" => new HMACSHA256(k), + "SHA512" => new HMACSHA512(k), + _ => throw new NotSupportedException("Hash algorithm is not supported: " + algorithm), + }; + try + { Span bc = stackalloc byte[8]; BinaryPrimitives.WriteInt64BigEndian(bc, c); + int outLength = hmac.HashSize / 8; Span hs = stackalloc byte[outLength]; if (!hmac.TryComputeHash(bc, hs, out _)) throw new InvalidOperationException(); int offset = hs[hs.Length - 1] & 0xf; - int code = (hs[offset] & 0x7f) << 24 | hs[offset + 1] << 16 | hs[offset + 2] << 8 | hs[offset + 3]; + int binary = + (hs[offset] & 0x7f) << 24 | + (hs[offset + 1] & 0xff) << 16 | + (hs[offset + 2] & 0xff) << 8 | + (hs[offset + 3] & 0xff); + + // integer mod instead of Math.Pow + int mod = 1; + for (int i = 0; i < digits; i++) mod *= 10; - return (code % (int)Math.Pow(10, digits)).ToString().PadLeft(digits, '0'); + return (binary % mod).ToString().PadLeft(digits, '0'); } finally { - hmac?.Dispose(); + hmac.Dispose(); } } - private static string TOTP(byte[] k, DateTime dateTime, int t0 = 0, int period = 30, int digits = 6, string algorithm = "SHA1") { long t = (long)Math.Floor(((dateTime - DateTime.UnixEpoch).TotalSeconds - t0) / period); return HOTP(k, t, digits, algorithm); } - #endregion #region public @@ -116,32 +121,24 @@ public string GetTOTP() return GetTOTP(DateTime.UtcNow); } + public string GetTOTP(DateTime dateTime) { - return TOTP(_key, dateTime, 0, _keyUri.Period, _keyUri.Digits, _keyUri.Algorithm); + var utc = dateTime.Kind == DateTimeKind.Utc ? dateTime : dateTime.ToUniversalTime(); + return TOTP(_key, utc, 0, KeyUri.Period, KeyUri.Digits, KeyUri.Algorithm); } - public bool IsTOTPValid(string totp, byte fudge = 10) + public bool IsTOTPValid(string totp, int windowSteps = 1) { DateTime utcNow = DateTime.UtcNow; + if (ConstantTimeEquals(GetTOTP(utcNow), totp)) return true; - if (GetTOTP(utcNow).Equals(totp)) - return true; - - int period = _keyUri.Period; - int seconds; - - for (int i = 1; i <= fudge; i++) + int period = KeyUri.Period; + for (int i = 1; i <= windowSteps; i++) { - seconds = i * period; - - if (GetTOTP(utcNow.AddSeconds(seconds)).Equals(totp)) - return true; - - if (GetTOTP(utcNow.AddSeconds(-seconds)).Equals(totp)) - return true; + if (ConstantTimeEquals(GetTOTP(utcNow.AddSeconds(i * period)), totp)) return true; + if (ConstantTimeEquals(GetTOTP(utcNow.AddSeconds(-i * period)), totp)) return true; } - return false; } @@ -149,8 +146,7 @@ public bool IsTOTPValid(string totp, byte fudge = 10) #region properties - public AuthenticatorKeyUri KeyUri - { get { return _keyUri; } } + public AuthenticatorKeyUri KeyUri { get; } #endregion } diff --git a/TechnitiumLibrary.Tests/MSTestSettings.cs b/TechnitiumLibrary.Tests/MSTestSettings.cs new file mode 100644 index 0000000..e466aa1 --- /dev/null +++ b/TechnitiumLibrary.Tests/MSTestSettings.cs @@ -0,0 +1,3 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; + +[assembly: Parallelize(Scope = ExecutionScope.MethodLevel)] diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.ByteTree/ByteTreeTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.ByteTree/ByteTreeTests.cs new file mode 100644 index 0000000..611ed10 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.ByteTree/ByteTreeTests.cs @@ -0,0 +1,315 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using TechnitiumLibrary.ByteTree; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.ByteTree +{ + [TestClass] + public sealed class ByteTreeTests + { + private static byte[] Key(params byte[] b) => b; + + // --------------------------- + // ADD + GET + // --------------------------- + [TestMethod] + public void Add_ShouldInsertValue_WhenKeyDoesNotExist() + { + // GIVEN + var tree = new ByteTree(); + + // WHEN + tree.Add(Key(1, 2, 3), "value"); + + // THEN + Assert.AreEqual("value", tree[Key(1, 2, 3)]); + } + + [TestMethod] + public void Add_ShouldThrow_WhenKeyExists() + { + // GIVEN + var tree = new ByteTree(); + tree.Add(Key(4), "first"); + + // WHEN – THEN + Assert.ThrowsExactly(() => + tree.Add(Key(4), "duplicate")); + } + + [TestMethod] + public void Add_ShouldThrow_WhenKeyNull() + { + var tree = new ByteTree(); + Assert.ThrowsExactly(() => tree.Add(null, "x")); + } + + // --------------------------- + // TryAdd + // --------------------------- + [TestMethod] + public void TryAdd_ShouldReturnTrue_WhenKeyAdded() + { + var tree = new ByteTree(); + var result = tree.TryAdd(Key(1), "v"); + Assert.IsTrue(result); + } + + [TestMethod] + public void TryAdd_ShouldReturnFalse_WhenKeyExists() + { + var tree = new ByteTree(); + tree.Add(Key(5), "initial"); + + var result = tree.TryAdd(Key(5), "other"); + + Assert.IsFalse(result); + Assert.AreEqual("initial", tree[Key(5)]); + } + + [TestMethod] + public void TryAdd_ShouldThrow_WhenKeyNull() + { + var tree = new ByteTree(); + Assert.ThrowsExactly(() => tree.TryAdd(null, "x")); + } + + // --------------------------- + // GET operations + // --------------------------- + [TestMethod] + public void TryGet_ShouldReturnTrue_WhenKeyExists() + { + var tree = new ByteTree(); + tree.Add(Key(1, 2), "data"); + + var found = tree.TryGet(Key(1, 2), out var value); + + Assert.IsTrue(found); + Assert.AreEqual("data", value); + } + + [TestMethod] + public void TryGet_ShouldReturnFalse_WhenMissing() + { + var tree = new ByteTree(); + + var result = tree.TryGet(Key(9), out var value); + + Assert.IsFalse(result); + Assert.IsNull(value); + } + + [TestMethod] + public void TryGet_ShouldThrow_WhenNull() + { + var tree = new ByteTree(); + Assert.ThrowsExactly(() => tree.TryGet(null, out _)); + } + + // --------------------------- + // ContainsKey + // --------------------------- + [TestMethod] + public void ContainsKey_ShouldReturnTrue_WhenKeyPresent() + { + var tree = new ByteTree(); + tree.Add(Key(3, 3), "v"); + + Assert.IsTrue(tree.ContainsKey(Key(3, 3))); + } + + [TestMethod] + public void ContainsKey_ShouldReturnFalse_WhenKeyMissing() + { + var tree = new ByteTree(); + Assert.IsFalse(tree.ContainsKey(Key(3, 100))); + } + + [TestMethod] + public void ContainsKey_ShouldThrow_WhenNull() + { + var tree = new ByteTree(); + Assert.ThrowsExactly(() => tree.ContainsKey(null)); + } + + // --------------------------- + // Remove + // --------------------------- + [TestMethod] + public void TryRemove_ShouldReturnTrue_WhenKeyExists() + { + var tree = new ByteTree(); + tree.Add(Key(10), "v"); + + var result = tree.TryRemove(Key(10), out var removed); + + Assert.IsTrue(result); + Assert.AreEqual("v", removed); + Assert.IsFalse(tree.ContainsKey(Key(10))); + } + + [TestMethod] + public void TryRemove_ShouldReturnFalse_WhenMissing() + { + var tree = new ByteTree(); + var result = tree.TryRemove(Key(11), out var removed); + + Assert.IsFalse(result); + Assert.IsNull(removed); + } + + [TestMethod] + public void TryRemove_ShouldThrow_WhenNull() + { + var tree = new ByteTree(); + Assert.ThrowsExactly(() => tree.TryRemove(null, out _)); + } + + // --------------------------- + // TryUpdate + // --------------------------- + [TestMethod] + public void TryUpdate_ShouldReplaceValue_WhenComparisonMatches() + { + var tree = new ByteTree(); + tree.Add(Key(5), "old"); + + var updated = tree.TryUpdate(Key(5), "new", "old"); + + Assert.IsTrue(updated); + Assert.AreEqual("new", tree[Key(5)]); + } + + [TestMethod] + public void TryUpdate_ShouldReturnFalse_WhenComparisonDoesNotMatch() + { + var tree = new ByteTree(); + tree.Add(Key(7), "original"); + + var updated = tree.TryUpdate(Key(7), "attempt", "different"); + + Assert.IsFalse(updated); + Assert.AreEqual("original", tree[Key(7)]); + } + + // --------------------------- + // AddOrUpdate + // --------------------------- + [TestMethod] + public void AddOrUpdate_ShouldInsert_WhenMissing() + { + var tree = new ByteTree(); + + var val = tree.AddOrUpdate( + Key(1, 1), + _ => "create", + (_, old) => old + "update"); + + Assert.AreEqual("create", val); + } + + [TestMethod] + public void AddOrUpdate_ShouldModify_WhenExists() + { + var tree = new ByteTree(); + tree.Add(Key(1, 2), "first"); + + var updated = tree.AddOrUpdate( + Key(1, 2), + _ => "ignored", + (_, old) => old + "_changed"); + + Assert.AreEqual("first_changed", updated); + } + + // --------------------------- + // Indexer get/set + // --------------------------- + [TestMethod] + public void Indexer_Get_ShouldReturnExactValue() + { + var tree = new ByteTree(); + tree.Add(Key(99), "stored"); + + Assert.AreEqual("stored", tree[Key(99)]); + } + + [TestMethod] + public void Indexer_Set_ShouldOverwriteFormerValue() + { + var tree = new ByteTree(); + tree[Key(5, 5)] = "initial"; + + tree[Key(5, 5)] = "updated"; + + Assert.AreEqual("updated", tree[Key(5, 5)]); + } + + [TestMethod] + public void Indexer_Get_ShouldThrow_WhenMissingKey() + { + var tree = new ByteTree(); + Assert.ThrowsExactly(() => + _ = tree[Key(8, 8)]); + } + + [TestMethod] + public void Indexer_ShouldThrow_WhenNullKey() + { + var tree = new ByteTree(); + Assert.ThrowsExactly(() => tree[null] = "x"); + } + + // --------------------------- + // Enumeration + // --------------------------- + [TestMethod] + public void Enumerator_ShouldYieldExistingValues() + { + var tree = new ByteTree(); + tree.Add(Key(1), "x"); + tree.Add(Key(2), "y"); + tree.Add(Key(3), "z"); + + var values = tree.ToList(); + + Assert.HasCount(3, values); + CollectionAssert.AreEquivalent(new[] { "x", "y", "z" }, values); + } + + [TestMethod] + public void ReverseEnumerable_ShouldYieldInReverseOrder() + { + var tree = new ByteTree(); + tree.Add(Key(0), "a"); + tree.Add(Key(1), "b"); + tree.Add(Key(255), "c"); + + var result = tree.GetReverseEnumerable().ToList(); + + Assert.HasCount(3, result); + Assert.AreEqual("c", result[0]); // last sorted key + Assert.AreEqual("b", result[1]); + Assert.AreEqual("a", result[2]); + } + + // --------------------------- + // Clear + // --------------------------- + [TestMethod] + public void Clear_ShouldEraseAllData() + { + var tree = new ByteTree(); + tree.Add(Key(1), "x"); + tree.Add(Key(2), "y"); + + tree.Clear(); + + Assert.IsTrue(tree.IsEmpty); + Assert.IsFalse(tree.ContainsKey(Key(1))); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/BinaryReaderExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/BinaryReaderExtensionsTests.cs new file mode 100644 index 0000000..4cee164 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/BinaryReaderExtensionsTests.cs @@ -0,0 +1,170 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.IO; +using System.Linq; +using System.Text; +using TechnitiumLibrary.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.IO +{ + [TestClass] + public sealed class BinaryReaderExtensionsTests + { + private static BinaryReader ReaderOf(params byte[] bytes) + { + return new BinaryReader(new MemoryStream(bytes)); + } + + // ----------------------------------------------- + // ReadLength() + // ----------------------------------------------- + + [TestMethod] + public void ReadLength_ShouldReadSingleByteLengths() + { + // GIVEN + var reader = ReaderOf(0x05); + + // WHEN + var length = reader.ReadLength(); + + // THEN + Assert.AreEqual(5, length); + Assert.AreEqual(1, reader.BaseStream.Position); + } + + [TestMethod] + public void ReadLength_ShouldReadMultiByteBigEndianLengths() + { + // GIVEN + // 0x82 => 2-byte length follows → value = 0x01 0x2C → 300 decimal + var reader = ReaderOf(0x82, 0x01, 0x2C); + + // WHEN + var length = reader.ReadLength(); + + // THEN + Assert.AreEqual(300, length); + Assert.AreEqual(3, reader.BaseStream.Position); + } + + [TestMethod] + public void ReadLength_ShouldThrow_WhenLengthPrefixTooLarge() + { + // GIVEN + // lower 7 bits = 0x05, meaning "next 5 bytes", exceeding allowed 4 + var reader = ReaderOf(0x85); + + // WHEN-THEN + Assert.ThrowsExactly(() => reader.ReadLength()); + } + + // ----------------------------------------------- + // ReadBuffer() + // ----------------------------------------------- + + [TestMethod] + public void ReadBuffer_ShouldReturnBytes_WhenLengthPrefixed() + { + // GIVEN + // length=3, then bytes 0xAA, 0xBB, 0xCC + var reader = ReaderOf(0x03, 0xAA, 0xBB, 0xCC); + + // WHEN + var data = reader.ReadBuffer(); + + // THEN + Assert.HasCount(3, data); + CollectionAssert.AreEqual(new byte[] { 0xAA, 0xBB, 0xCC }, data); + } + + // ----------------------------------------------- + // ReadShortString() + // ----------------------------------------------- + + [TestMethod] + public void ReadShortString_ShouldDecodeUtf8StringCorrectly() + { + // GIVEN + var text = "Hello"; + var encoded = Encoding.UTF8.GetBytes(text); + + var bytes = new byte[] { (byte)encoded.Length }.Concat(encoded).ToArray(); + var reader = ReaderOf(bytes); + + // WHEN + var result = reader.ReadShortString(); + + // THEN + Assert.AreEqual(text, result); + } + + [TestMethod] + public void ReadShortString_ShouldUseSpecifiedEncoding() + { + // GIVEN + var text = "Å"; + var encoding = Encoding.UTF32; + var encoded = encoding.GetBytes(text); + + var bytes = new byte[] { (byte)encoded.Length }.Concat(encoded).ToArray(); + var reader = ReaderOf(bytes); + + // WHEN + var result = reader.ReadShortString(encoding); + + // THEN + Assert.AreEqual(text, result); + } + + // ----------------------------------------------- + // ReadDateTime() + // ----------------------------------------------- + + [TestMethod] + public void ReadDateTime_ShouldConvertEpochMilliseconds() + { + // GIVEN + var expected = new DateTime(2024, 01, 01, 12, 00, 00, DateTimeKind.Utc); + long millis = (long)(expected - DateTime.UnixEpoch).TotalMilliseconds; + + byte[] encoded = BitConverter.GetBytes(millis); + if (BitConverter.IsLittleEndian) + Array.Reverse(encoded); + + var reader = ReaderOf(encoded.Reverse().ToArray()); + + // WHEN + var result = reader.ReadDateTime(); + + // THEN + Assert.AreEqual(expected, result); + } + + // ----------------------------------------------- + // Invalid stream / broken data integrity + // ----------------------------------------------- + + [TestMethod] + public void ReadShortString_ShouldThrow_WhenNotEnoughBytes() + { + // GIVEN + // says length=4 but only 2 follow + var reader = ReaderOf(0x04, 0xAA, 0xBB); + + // WHEN-THEN + Assert.ThrowsExactly(() => reader.ReadShortString()); + } + + [TestMethod] + public void ReadBuffer_ShouldThrow_WhenStreamEndsEarly() + { + // GIVEN + // prefixed length=5, only 3 bytes exist + var reader = ReaderOf(0x05, 0x10, 0x20, 0x30); + + // WHEN-THEN + Assert.ThrowsExactly(() => reader.ReadBuffer()); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/BinaryWriterExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/BinaryWriterExtensionsTests.cs new file mode 100644 index 0000000..7551a78 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/BinaryWriterExtensionsTests.cs @@ -0,0 +1,174 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.IO; +using System.Linq; +using System.Text; +using TechnitiumLibrary.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.IO +{ + [TestClass] + public sealed class BinaryWriterExtensionsTests + { + private static (BinaryWriter writer, MemoryStream stream) CreateWriter() + { + var ms = new MemoryStream(); + var bw = new BinaryWriter(ms); + return (bw, ms); + } + + private static byte[] WrittenBytes(MemoryStream ms) => + ms.ToArray(); + + // --------------------------------------- + // WriteLength() tests + // --------------------------------------- + + [TestMethod] + public void WriteLength_ShouldEncodeSingleByte_WhenLessThan128() + { + // GIVEN + var (bw, ms) = CreateWriter(); + + // WHEN + bw.WriteLength(42); + + // THEN + CollectionAssert.AreEqual(new byte[] { 42 }, WrittenBytes(ms)); + } + + [TestMethod] + public void WriteLength_ShouldEncodeMultiByte_BigEndianForm() + { + // GIVEN + var (bw, ms) = CreateWriter(); + + // WHEN + // length = 0x0000012C (300 decimal) + bw.WriteLength(300); + + // THEN + // Prefix = 0x82 (2 bytes follow) + // Then big-endian 01 2C + CollectionAssert.AreEqual( + new byte[] { 0x82, 0x01, 0x2C }, + WrittenBytes(ms) + ); + } + + // --------------------------------------- + // WriteBuffer() + // --------------------------------------- + + [TestMethod] + public void WriteBuffer_ShouldPrefixLength_AndWriteBytes() + { + // GIVEN + var (bw, ms) = CreateWriter(); + var data = new byte[] { 0xAA, 0xBB, 0xCC }; + + // WHEN + bw.WriteBuffer(data); + + // THEN + CollectionAssert.AreEqual( + new byte[] { 0x03, 0xAA, 0xBB, 0xCC }, + WrittenBytes(ms) + ); + } + + [TestMethod] + public void WriteBuffer_WithOffset_ShouldWriteExpectedSegment() + { + // GIVEN + var (bw, ms) = CreateWriter(); + var data = new byte[] { 1, 2, 3, 4, 5 }; + + // WHEN + bw.WriteBuffer(data, offset: 1, count: 3); + + // THEN + CollectionAssert.AreEqual( + new byte[] { 0x03, 2, 3, 4 }, + WrittenBytes(ms) + ); + } + + // --------------------------------------- + // WriteShortString() + // --------------------------------------- + + [TestMethod] + public void WriteShortString_ShouldWriteUtf8EncodedWithLength() + { + // GIVEN + var (bw, ms) = CreateWriter(); + var text = "Hello"; + var utf8 = Encoding.UTF8.GetBytes(text); + + // WHEN + bw.WriteShortString(text); + + // THEN + var expected = new byte[] { (byte)utf8.Length } + .Concat(utf8) + .ToArray(); + + CollectionAssert.AreEqual(expected, WrittenBytes(ms)); + } + + [TestMethod] + public void WriteShortString_ShouldUseSpecifiedEncoding() + { + // GIVEN + var (bw, ms) = CreateWriter(); + var text = "Å"; + var enc = Encoding.UTF32; + var bytes = enc.GetBytes(text); + + // WHEN + bw.WriteShortString(text, enc); + + // THEN + var expected = new byte[] { (byte)bytes.Length } + .Concat(bytes) + .ToArray(); + + CollectionAssert.AreEqual(expected, WrittenBytes(ms)); + } + + [TestMethod] + public void WriteShortString_ShouldThrow_WhenStringTooLong() + { + // GIVEN + var (bw, _) = CreateWriter(); + var input = new string('x', 256); // UTF-8 => 256 bytes + + // WHEN–THEN + Assert.ThrowsExactly(() => + bw.WriteShortString(input) + ); + } + + // --------------------------------------- + // Write(DateTime) + // --------------------------------------- + + [TestMethod] + public void WriteDate_ShouldEncodeMillisecondsFromUnixEpoch() + { + // GIVEN + var expected = new DateTime(2024, 1, 2, 12, 00, 00, DateTimeKind.Utc); + var millis = (long)(expected - DateTime.UnixEpoch).TotalMilliseconds; + + var bytes = BitConverter.GetBytes(millis); + var (bw, ms) = CreateWriter(); + + // WHEN + bw.Write(expected); + + // THEN + CollectionAssert.AreEqual(bytes, WrittenBytes(ms)); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/JointTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/JointTests.cs new file mode 100644 index 0000000..b2d7917 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/JointTests.cs @@ -0,0 +1,186 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.IO; +using System.Threading.Tasks; +using TechnitiumLibrary.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.IO +{ + [TestClass] + public sealed class JointTests + { + private static async Task WaitForCopyCompletion() + { + // The copy tasks run asynchronously and Joint.Dispose() executes + // when either side reaches EOF. Wait slightly longer than default buffering time. + await Task.Delay(80); + } + + // --------------------------------------- + // Constructor and property access + // --------------------------------------- + + [TestMethod] + public void Constructor_ShouldStoreStreams() + { + // GIVEN + var s1 = new MemoryStream(); + var s2 = new MemoryStream(); + + // WHEN + var joint = new Joint(s1, s2); + + // THEN + Assert.AreSame(s1, joint.Stream1); + Assert.AreSame(s2, joint.Stream2); + } + + // --------------------------------------- + // Data transfer behavior + // --------------------------------------- + + [TestMethod] + public async Task Start_ShouldCopyData_FromStream1ToStream2() + { + // GIVEN + var sourceData = new byte[] { 1, 2, 3, 4 }; + using var s1 = new MemoryStream(sourceData); + using var s2 = new MemoryStream(); + using var joint = new Joint(s1, s2); + + // WHEN + joint.Start(); + await WaitForCopyCompletion(); + + // THEN + var result = s2.ToArray(); + CollectionAssert.AreEqual(sourceData, result); + } + + [TestMethod] + public async Task Start_ShouldCopyData_FromStream2ToStream1() + { + // GIVEN + var sourceData = new byte[] { 7, 8, 9 }; + using var s1 = new MemoryStream(); + using var s2 = new MemoryStream(sourceData); + using var joint = new Joint(s1, s2); + + // WHEN + joint.Start(); + await WaitForCopyCompletion(); + + // THEN + var result = s1.ToArray(); + CollectionAssert.AreEqual(sourceData, result); + } + + // --------------------------------------- + // Empty stream scenarios + // --------------------------------------- + + [TestMethod] + public async Task Start_ShouldSupportEmptyStreams() + { + // GIVEN + using var s1 = new MemoryStream(); + using var s2 = new MemoryStream(); + using var joint = new Joint(s1, s2); + + // WHEN + joint.Start(); + await WaitForCopyCompletion(); + + // THEN + var buff1 = s1.ToArray(); + var buff2 = s2.ToArray(); + + CollectionAssert.AreEqual(Array.Empty(), buff1); + CollectionAssert.AreEqual(Array.Empty(), buff2); + } + + // --------------------------------------- + // Disposal semantics + // --------------------------------------- + + [TestMethod] + public async Task Dispose_ShouldCloseStreams() + { + // GIVEN + var s1 = new MemoryStream(new byte[] { 10 }); + var s2 = new MemoryStream(new byte[] { 20 }); + var joint = new Joint(s1, s2); + + // WHEN + joint.Dispose(); + await WaitForCopyCompletion(); + + // THEN + Assert.ThrowsExactly(() => { var _ = s1.Length; }); + Assert.ThrowsExactly(() => { var _ = s2.Length; }); + } + + [TestMethod] + public void Dispose_ShouldBeIdempotent() + { + // GIVEN + var s1 = new MemoryStream(); + var s2 = new MemoryStream(); + var joint = new Joint(s1, s2); + + // WHEN + joint.Dispose(); + joint.Dispose(); + joint.Dispose(); // Should not throw + + // THEN + Assert.IsTrue(true); // No exception was thrown + } + + // --------------------------------------- + // Disposal callback behavior + // --------------------------------------- + + [TestMethod] + public void Dispose_ShouldRaiseDisposingEvent() + { + // GIVEN + using var s1 = new MemoryStream(); + using var s2 = new MemoryStream(); + var joint = new Joint(s1, s2); + + bool raised = false; + joint.Disposing += (_, __) => raised = true; + + // WHEN + joint.Dispose(); + + // THEN + Assert.IsTrue(raised); + } + + // --------------------------------------- + // Concurrency semantics + // --------------------------------------- + + [TestMethod] + public async Task Start_ShouldDisposeOnce_WhenBothDirectionsComplete() + { + // GIVEN + using var s1 = new MemoryStream(new byte[] { 1 }); + using var s2 = new MemoryStream(new byte[] { 2 }); + + using var joint = new Joint(s1, s2); + + int disposedCount = 0; + joint.Disposing += (_, __) => disposedCount++; + + // WHEN + joint.Start(); + await WaitForCopyCompletion(); + + // THEN + Assert.AreEqual(1, disposedCount, "Disposing must fire only once"); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/OffsetStreamTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/OffsetStreamTests.cs new file mode 100644 index 0000000..c738d88 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/OffsetStreamTests.cs @@ -0,0 +1,245 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.IO; +using System.Threading.Tasks; +using TechnitiumLibrary.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.IO +{ + [TestClass] + public sealed class OffsetStreamTests + { + private static MemoryStream CreateStream(byte[] data) => new MemoryStream(data, writable: true); + + // ------------------------------------------------------ + // CONSTRUCTION & BASIC METADATA + // ------------------------------------------------------ + + [TestMethod] + public void Constructor_ShouldExposeCorrectBasicProperties() + { + // GIVEN + var source = CreateStream(new byte[] { 1, 2, 3, 4, 5 }); + + // WHEN + var offsetStream = new OffsetStream(source, offset: 1, length: 3); + + // THEN + Assert.AreEqual(3, offsetStream.Length); + Assert.AreEqual(0, offsetStream.Position); + Assert.IsTrue(offsetStream.CanRead); + Assert.IsTrue(offsetStream.CanSeek); + } + + [TestMethod] + public void Constructor_ShouldRespectReadOnlyFlag() + { + // GIVEN + var source = CreateStream(new byte[10]); + + // WHEN + var offsetStream = new OffsetStream(source, readOnly: true); + + // THEN + Assert.IsFalse(offsetStream.CanWrite); + } + + // ------------------------------------------------------ + // READ OPERATIONS + // ------------------------------------------------------ + + [TestMethod] + public void Read_ShouldReturnSegmentWithinBounds() + { + // GIVEN + var source = CreateStream(new byte[] { 10, 20, 30, 40, 50 }); + var offsetStream = new OffsetStream(source, offset: 1, length: 3); + + var buffer = new byte[10]; + + // WHEN + var readCount = offsetStream.Read(buffer, 0, 10); + + // THEN + Assert.AreEqual(3, readCount); + CollectionAssert.AreEqual(new byte[] { 20, 30, 40 }, buffer[..3]); + } + + [TestMethod] + public void Read_ShouldReturnZero_WhenPastLength() + { + // GIVEN + var source = CreateStream(new byte[] { 1, 2, 3, 4 }); + var offsetStream = new OffsetStream(source, offset: 2, length: 1); + + var buffer = new byte[5]; + offsetStream.Position = 1; + + // WHEN + var count = offsetStream.Read(buffer, 0, 5); + + // THEN + Assert.AreEqual(0, count); + } + + [TestMethod] + public void ReadAsync_ShouldReturnCorrectData() + { + // GIVEN + var source = CreateStream(new byte[] { 9, 8, 7, 6 }); + var offsetStream = new OffsetStream(source, offset: 1, length: 2); + var buffer = new byte[10]; + + // WHEN + var count = offsetStream.ReadAsync(buffer, 0, 10, TestContext.CancellationToken).Result; + + // THEN + Assert.AreEqual(2, count); + CollectionAssert.AreEqual(new byte[] { 8, 7 }, buffer[..2]); + } + + // ------------------------------------------------------ + // WRITE OPERATIONS + // ------------------------------------------------------ + + [TestMethod] + public void Write_ShouldPlaceDataAtOffset() + { + // GIVEN + var source = CreateStream(new byte[] { 1, 2, 3, 4 }); + var offsetStream = new OffsetStream(source, offset: 1, length: 2); + + // WHEN + offsetStream.Write("23"u8.ToArray(), 0, 2); + + // THEN + CollectionAssert.AreEqual(new byte[] { 1, 50, 51, 4 }, source.ToArray()); + } + + [TestMethod] + public void Write_ShouldExtendLength() + { + // GIVEN + var source = CreateStream(new byte[] { 1, 2, 3 }); + var offsetStream = new OffsetStream(source, offset: 0, length: 2); + + // WHEN + offsetStream.Position = 2; + offsetStream.Write("\t"u8.ToArray(), 0, 1); + + // THEN + Assert.AreEqual(3, offsetStream.Length); + } + + [TestMethod] + public void Write_ShouldThrow_WhenReadOnly() + { + // GIVEN + var source = CreateStream(new byte[] { 1, 2, 3 }); + var offsetStream = new OffsetStream(source, readOnly: true); + + // WHEN–THEN + Assert.ThrowsExactly(() => + offsetStream.Write(new byte[] { 0 }, 0, 1)); + } + + // ------------------------------------------------------ + // SEEK OPERATIONS + // ------------------------------------------------------ + + [TestMethod] + public void Seek_ShouldMoveWithinValidRange() + { + // GIVEN + var source = CreateStream(new byte[] { 1, 2, 3, 4 }); + var offsetStream = new OffsetStream(source, offset: 0, length: 4); + + // WHEN + var newPos = offsetStream.Seek(2, SeekOrigin.Begin); + + // THEN + Assert.AreEqual(2, newPos); + Assert.AreEqual(2, offsetStream.Position); + } + + [TestMethod] + public void Seek_ShouldThrow_WhenSeekingPastEnd() + { + // GIVEN + var source = CreateStream(new byte[] { 1, 2, 3 }); + var offsetStream = new OffsetStream(source, offset: 0, length: 3); + + // WHEN–THEN + Assert.ThrowsExactly(() => + offsetStream.Seek(4, SeekOrigin.Begin)); + } + + // ------------------------------------------------------ + // DISPOSAL OWNERSHIP + // ------------------------------------------------------ + + [TestMethod] + public void Dispose_ShouldCloseBaseStream_WhenOwnsStream() + { + // GIVEN + var source = CreateStream(new byte[] { 1 }); + var offsetStream = new OffsetStream(source, ownsStream: true); + + // WHEN + offsetStream.Dispose(); + + // THEN + Assert.ThrowsExactly(() => source.ReadByte()); + } + + [TestMethod] + public void Dispose_ShouldNotCloseBaseStream_WhenNotOwned() + { + // GIVEN + var source = CreateStream(new byte[] { 1 }); + var offsetStream = new OffsetStream(source, ownsStream: false); + + // WHEN + offsetStream.Dispose(); + + // THEN + Assert.AreEqual(1, source.ReadByte()); + } + + // ------------------------------------------------------ + // WRITETO & WRITETOASYNC + // ------------------------------------------------------ + + [TestMethod] + public void WriteTo_ShouldCopyOnlyOffsetRange() + { + // GIVEN + var source = CreateStream(new byte[] { 10, 20, 30, 40 }); + var offsetStream = new OffsetStream(source, offset: 1, length: 2); + var target = new MemoryStream(); + + // WHEN + offsetStream.WriteTo(target); + + // THEN + CollectionAssert.AreEqual(new byte[] { 20, 30 }, target.ToArray()); + } + + [TestMethod] + public async Task WriteToAsync_ShouldCopyOnlyOffsetRange() + { + // GIVEN + var source = CreateStream("2 + new MemoryStream(bytes, writable: true); + + private static PackageItem CreateMinimalWritable() + { + var ms = StreamOf(1, 2, 3); + return new PackageItem("file.bin", ms); + } + + // --------------------------------------------------------- + // CONSTRUCTION + // --------------------------------------------------------- + + [TestMethod] + public void Constructor_ShouldCreateItemFromStream() + { + using var ms = StreamOf(10, 20, 30); + using var item = new PackageItem("abc.txt", ms); + + Assert.AreEqual("abc.txt", item.Name); + Assert.IsFalse(item.IsAttributeSet(PackageItemAttributes.ExecuteFile)); + Assert.AreEqual(ms, item.DataStream); + } + + [TestMethod] + public void Constructor_FromFilePath_ShouldCaptureAttributesAndOwnStream() + { + // Create an isolated private subfolder under temp, + // because direct writes to global temp root are unsafe. + string secureTempRoot = Path.Combine( + Path.GetTempPath(), + "pkgtest_" + Guid.NewGuid().ToString("N")); + + Directory.CreateDirectory(secureTempRoot); + + string path = Path.Combine( + secureTempRoot, + Path.GetRandomFileName()); + + // Create securely using exclusive, non-shareable access + using (var file = new FileStream( + path, + FileMode.CreateNew, + FileAccess.ReadWrite, + FileShare.None)) + { + file.Write(new byte[] { 9, 8, 7 }); + } + + File.SetLastWriteTimeUtc( + path, + new DateTime(2022, 5, 1, 12, 0, 0, DateTimeKind.Utc)); + + try + { + using var item = new PackageItem(path, PackageItemAttributes.ExecuteFile); + + Assert.AreEqual(Path.GetFileName(path), item.Name); + Assert.IsTrue(item.IsAttributeSet(PackageItemAttributes.ExecuteFile)); + Assert.IsGreaterThanOrEqualTo(3, item.DataStream.Length); + } + finally + { + // Secure cleanup: remove file then folder + if (File.Exists(path)) + File.Delete(path); + + if (Directory.Exists(secureTempRoot)) + Directory.Delete(secureTempRoot, recursive: true); + } + } + + + // --------------------------------------------------------- + // WRITE FORMAT + RE-PARSE + // --------------------------------------------------------- + + private static PackageItem Roundtrip(PackageItem source) + { + var buffer = new MemoryStream(); // do NOT dispose here + source.WriteTo(buffer); + + buffer.Position = 0; + return PackageItem.Parse(buffer); + } + + [TestMethod] + public void WriteThenParse_ShouldReturnEquivalentName() + { + using var item = CreateMinimalWritable(); + using var parsed = Roundtrip(item); + + Assert.AreEqual(item.Name, parsed.Name); + } + + [TestMethod] + public void WriteThenParse_ShouldPreserveTimestamp() + { + var dt = new DateTime(2022, 10, 30, 11, 0, 0, DateTimeKind.Utc); + using var item = new PackageItem("f", dt, StreamOf(1, 2, 3)); + using var parsed = Roundtrip(item); + + Assert.AreEqual(dt, parsed.LastModifiedUTC); + } + + [TestMethod] + public void WriteThenParse_ShouldPreserveAttributes() + { + using var item = new PackageItem("a", DateTime.UtcNow, StreamOf(1), + attributes: PackageItemAttributes.FixedExtractLocation); + + using var parsed = Roundtrip(item); + + Assert.IsTrue(parsed.IsAttributeSet(PackageItemAttributes.FixedExtractLocation)); + } + + [TestMethod] + public void WriteThenParse_ShouldPreserveData() + { + using var item = CreateMinimalWritable(); + using var parsed = Roundtrip(item); + + using var reader = new BinaryReader(parsed.DataStream); + + var bytes = reader.ReadBytes(3); + + CollectionAssert.AreEqual(new byte[] { 1, 2, 3 }, bytes); + } + + // --------------------------------------------------------- + // CUSTOM EXTRACT LOCATION + // --------------------------------------------------------- + + [TestMethod] + public void WriteThenParse_WithCustomLocation_ShouldRoundtrip() + { + // Create a private temp subfolder so location is not globally predictable + string secureTempRoot = Path.Combine( + Path.GetTempPath(), + "pkgtest_" + Guid.NewGuid().ToString("N")); + + Directory.CreateDirectory(secureTempRoot); + + try + { + using var item = new PackageItem( + "x.txt", + DateTime.UtcNow, + StreamOf(1, 2), + attributes: PackageItemAttributes.FixedExtractLocation, + extractTo: ExtractLocation.Custom, + extractToCustomLocation: secureTempRoot); + + using var parsed = Roundtrip(item); + + Assert.AreEqual(secureTempRoot, parsed.ExtractToCustomLocation); + } + finally + { + if (Directory.Exists(secureTempRoot)) + Directory.Delete(secureTempRoot, recursive: true); + } + } + + // --------------------------------------------------------- + // GET EXTRACTION PATH LOGIC + // --------------------------------------------------------- + + [TestMethod] + public void GetExtractionFilePath_ShouldRespectFixedAttribute() + { + using var item = new PackageItem("abc.dll", DateTime.UtcNow, + StreamOf(1), + attributes: PackageItemAttributes.FixedExtractLocation, + extractTo: ExtractLocation.System); + + var result = item.GetExtractionFilePath(ExtractLocation.Temp, null); + + // path must be under System, not requested Temp + var expectedRoot = Package.GetExtractLocation(ExtractLocation.System, null); + Assert.StartsWith(expectedRoot, result); + } + + [TestMethod] + public void GetExtractionFilePath_ShouldUseSuppliedLocation_WhenNotFixed() + { + using var item = new PackageItem("abc.dll", StreamOf(7)); + + var path = item.GetExtractionFilePath(ExtractLocation.Temp); + + Assert.StartsWith(Path.GetTempPath(), path); + } + + // --------------------------------------------------------- + // EXTRACTION TRANSACTION + // --------------------------------------------------------- + + [TestMethod] + public void Extract_ShouldBackupExisting_WhenOverwriteEnabled() + { + string target = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + var originalBytes = "c"u8.ToArray(); + + // Securely create target + using (var fs = new FileStream( + target, + FileMode.CreateNew, + FileAccess.ReadWrite, + FileShare.None)) + { + fs.Write(originalBytes, 0, originalBytes.Length); + } + + string? backupPath = null; + + try + { + using var item = CreateMinimalWritable(); + var log = item.Extract(target, overwrite: true); + + Assert.IsNotNull(log); + Assert.IsTrue(File.Exists(log.FilePath)); + + // Track for cleanup + backupPath = log.OriginalFilePath; + + Assert.IsTrue(File.Exists(backupPath), "Backup should exist"); + + CollectionAssert.AreEqual( + originalBytes, + File.ReadAllBytes(backupPath)); + + CollectionAssert.AreEqual( + new byte[] { 1, 2, 3 }, + File.ReadAllBytes(target)); + } + finally + { + if (File.Exists(target)) + File.Delete(target); + + // Now valid (conditional reachability eliminated) + if (!string.IsNullOrWhiteSpace(backupPath) && File.Exists(backupPath)) + File.Delete(backupPath); + } + } + + [TestMethod] + public void Extract_ShouldNotOverwrite_WhenFlagDisabled() + { + string target = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + var originalBytes = "X"u8.ToArray(); + + // Create file securely + using (var fs = new FileStream( + target, + FileMode.CreateNew, + FileAccess.ReadWrite, + FileShare.None)) + { + fs.Write(originalBytes, 0, originalBytes.Length); + } + + try + { + using var item = CreateMinimalWritable(); + var log = item.Extract(target, overwrite: false); + + Assert.IsNull(log, "Extract must return null when overwrite=false"); + CollectionAssert.AreEqual(originalBytes, File.ReadAllBytes(target)); + } + finally + { + // cleanup + if (File.Exists(target)) + File.Delete(target); + } + } + + // --------------------------------------------------------- + // PARSE ERROR SCENARIOS + // --------------------------------------------------------- + + [TestMethod] + public void Parse_ShouldThrow_WhenVersionIsUnsupported() + { + using var buffer = StreamOf("\t"u8.ToArray() /* invalid version */); + + Assert.ThrowsExactly(() => + { + var _ = PackageItem.Parse(buffer); + }); + } + + [TestMethod] + public void Parse_ShouldReturnNull_WhenEOFMarker() + { + using var buffer = StreamOf(0); + + var item = PackageItem.Parse(buffer); + + Assert.IsNull(item); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/PackageTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/PackageTests.cs new file mode 100644 index 0000000..ffdca92 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/PackageTests.cs @@ -0,0 +1,241 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.IO; +using System.Linq; +using System.Text; +using TechnitiumLibrary.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.IO +{ + [TestClass] + public sealed class PackageTests + { + private static MemoryStream CreateWritableStream() => new MemoryStream(); + + private static byte[] BuildEmptyPackageFile() + { + // Header: + // TP format id + // 01 version + // 00 EOF (no items) + return "TP"u8.ToArray() + .Append((byte)1) + .Append((byte)0) + .ToArray(); + } + + /// + /// Creates a serialized single PackageItem with name "A" and empty content. + /// + private static byte[] CreateMinimalItem() + { + using var ms = new MemoryStream(); + using var writer = new BinaryWriter(ms); + + // Write NAME field (short) + writer.Write((byte)1); // length + writer.Write("A"u8.ToArray()); // ASCII name + + // Extract location = 0 + writer.Write((byte)0); + + // Flags = 0 + writer.Write((byte)0); + + // File size = 0 (Int64) + writer.Write((long)0); + + // Because file size = 0, Write no content + return ms.ToArray(); + } + + private static void WriteItem(Stream stream) + { + using var data = new MemoryStream(); // empty payload + using var item = new PackageItem("A", data); + + item.WriteTo(stream); + } + + + + // ------------------------------------------------------------- + // CONSTRUCTION + // ------------------------------------------------------------- + + [TestMethod] + public void Constructor_ShouldWriteHeader_WhenCreating() + { + using var backing = CreateWritableStream(); + + using (var pkg = new Package(backing, PackageMode.Create)) + { + pkg.Close(); + } + + var data = backing.ToArray(); + + Assert.IsGreaterThanOrEqualTo(3, data.Length); + Assert.AreEqual("TP", Encoding.ASCII.GetString(data[..2])); + Assert.AreEqual(1, data[2]); // version marker + } + + [TestMethod] + public void Constructor_ShouldReadExisting_WhenOpening() + { + var bytes = BuildEmptyPackageFile(); + using var backing = new MemoryStream(bytes); + + using var pkg = new Package(backing, PackageMode.Open); + + Assert.IsEmpty(pkg.Items); + } + + [TestMethod] + public void Constructor_ShouldThrow_WhenInvalidHeader() + { + using var backing = new MemoryStream("XY"u8.ToArray()); + + Assert.ThrowsExactly(() => + new Package(backing, PackageMode.Open)); + } + + // ------------------------------------------------------------- + // MODE RESTRICTION + // ------------------------------------------------------------- + + [TestMethod] + public void AddItem_ShouldThrow_WhenNotInCreateMode() + { + using var backing = new MemoryStream(BuildEmptyPackageFile()); + using var pkg = new Package(backing, PackageMode.Open); + + Assert.ThrowsExactly(() => + { + // simulate write by raw call — not allowed in Open mode + pkg.AddItem(null); + }); + } + + [TestMethod] + public void Items_ShouldThrow_WhenNotInOpenMode() + { + using var backing = CreateWritableStream(); + using var pkg = new Package(backing, PackageMode.Create); + + Assert.ThrowsExactly(() => + { + var _ = pkg.Items; + }); + } + + // ------------------------------------------------------------- + // WRITE AND READ BACK + // ------------------------------------------------------------- + + [TestMethod] + public void WriteAndRead_ShouldReturnSameItems() + { + using var backing = CreateWritableStream(); + + // Write + using (var pkg = new Package(backing, PackageMode.Create)) + { + WriteItem(backing); + pkg.Close(); + } + + // Reopen + backing.Position = 0; + using var pkg2 = new Package(backing, PackageMode.Open); + + Assert.HasCount(1, pkg2.Items); + } + + [TestMethod] + public void Close_ShouldWriteEOF_Once() + { + using var backing = CreateWritableStream(); + using var pkg = new Package(backing, PackageMode.Create); + WriteItem(backing); + pkg.Close(); + var len1 = backing.Length; + pkg.Close(); + var len2 = backing.Length; + Assert.AreEqual(len1, len2); + } + + // ------------------------------------------------------------- + // STREAM OWNERSHIP + // ------------------------------------------------------------- + + [TestMethod] + public void Dispose_ShouldCloseOwnedStream() + { + // secure temp file creation + string tempFile = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); + + // create file exclusively before passing to Package + using (var fs = new FileStream( + tempFile, + FileMode.CreateNew, // guarantees file does not exist + FileAccess.ReadWrite, + FileShare.None)) // no external access allowed + { + fs.WriteByte(99); // write something so the file exists + } + + try + { + using (var pkg = new Package(tempFile, PackageMode.Create)) + pkg.Close(); // Close → flush EOF marker → close underlying stream + + using var fs = new FileStream(tempFile, FileMode.Open, FileAccess.Read); + Assert.IsGreaterThanOrEqualTo(3, fs.Length); + } + finally + { + if (File.Exists(tempFile)) + File.Delete(tempFile); + } + } + + [TestMethod] + public void Dispose_ShouldNotCloseExternalStream() + { + using var backing = CreateWritableStream(); + using (var pkg = new Package(backing, PackageMode.Create, ownsStream: false)) + pkg.Close(); + + // external stream still usable + backing.WriteByte(255); + backing.Position = 0; + } + + // ------------------------------------------------------------- + // INVALID FORMATS + // ------------------------------------------------------------- + + [TestMethod] + public void ShouldThrow_WhenMissingVersion() + { + using var backing = new MemoryStream("TP"u8.ToArray()); + + Assert.ThrowsExactly(() => + new Package(backing, PackageMode.Open)); + } + + [TestMethod] + public void ShouldThrow_WhenUnsupportedVersion() + { + var bytes = "TP"u8.ToArray() + .Concat("*"u8.ToArray()) // bogus version + .Concat(new byte[] { 0 }) + .ToArray(); + + using var backing = new MemoryStream(bytes); + + Assert.ThrowsExactly(() => + new Package(backing, PackageMode.Open)); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/PipeTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/PipeTests.cs new file mode 100644 index 0000000..e63544c --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/PipeTests.cs @@ -0,0 +1,175 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using TechnitiumLibrary.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.IO +{ + [TestClass] + public sealed class PipeTests + { + private static Pipe CreatePipe() => new Pipe(); + + // ------------------------------------------------------------ + // CONSTRUCTION + // ------------------------------------------------------------ + + [TestMethod] + public void Constructor_ShouldExposeTwoConnectedStreams() + { + Pipe p = CreatePipe(); + + Assert.IsNotNull(p.Stream1); + Assert.IsNotNull(p.Stream2); + + Assert.IsTrue(p.Stream1.CanRead); + Assert.IsTrue(p.Stream1.CanWrite); + + Assert.IsTrue(p.Stream2.CanRead); + Assert.IsTrue(p.Stream2.CanWrite); + } + + // ------------------------------------------------------------ + // BASIC DATA TRANSFER + // ------------------------------------------------------------ + + [TestMethod] + public void WriteOnStream1_ShouldBeReadableFromStream2() + { + Pipe pipe = CreatePipe(); + byte[] data = new byte[] { 1, 2, 3 }; + + pipe.Stream1.Write(data, 0, data.Length); + + byte[] buffer = new byte[10]; + int read = pipe.Stream2.Read(buffer, 0, 10); + + Assert.AreEqual(3, read); + CollectionAssert.AreEqual(new byte[] { 1, 2, 3 }, buffer[..3]); + } + + [TestMethod] + public void Read_ShouldReturnZero_WhenOtherSideDisposed() + { + Pipe pipe = CreatePipe(); + + pipe.Stream1.Dispose(); + + byte[] buffer = new byte[5]; + int read = pipe.Stream2.Read(buffer, 0, 5); + + Assert.AreEqual(0, read); + } + + // ------------------------------------------------------------ + // SEEK PROHIBITIONS + // ------------------------------------------------------------ + + [TestMethod] + public void Position_ShouldThrowOnGet() + { + Pipe pipe = CreatePipe(); + Assert.ThrowsExactly(() => _ = pipe.Stream1.Position); + } + + [TestMethod] + public void Position_ShouldThrowOnSet() + { + Pipe pipe = CreatePipe(); + Assert.ThrowsExactly(() => pipe.Stream1.Position = 10); + } + + [TestMethod] + public void Seek_ShouldThrow() + { + Pipe pipe = CreatePipe(); + Assert.ThrowsExactly(() => pipe.Stream1.Seek(10, SeekOrigin.Begin)); + } + + [TestMethod] + public void Length_ShouldThrow() + { + Pipe pipe = CreatePipe(); + Assert.ThrowsExactly(() => _ = pipe.Stream1.Length); + } + + // ------------------------------------------------------------ + // BUFFER BOUNDARY BEHAVIOR + // ------------------------------------------------------------ + + [TestMethod] + public void Write_ShouldBlockWhenBufferFull_ThenResumeAfterRead() + { + Pipe pipe = CreatePipe(); + Stream stream1 = pipe.Stream1; + Stream stream2 = pipe.Stream2; + + stream1.WriteTimeout = 2000; + stream2.ReadTimeout = 2000; + + byte[] large = new byte[64 * 1024]; // exactly buffer size + + // Fill buffer completely + stream1.Write(large, 0, large.Length); + + // Now write again, but on another thread + using Task t = Task.Run(() => + { + // Should block until read + stream1.Write(new byte[] { 7 }, 0, 1); + }, TestContext.CancellationToken); + + // Give writer thread chance to block + Thread.Sleep(100); + + // Now read entire buffer + byte[] readBuffer = new byte[large.Length]; + int readTotal = stream2.Read(readBuffer, 0, large.Length); + + Assert.AreEqual(large.Length, readTotal); + + // Now writer should have completed + t.Wait(TestContext.CancellationToken); + } + + [TestMethod] + public void Write_ShouldFailWhenTimeoutExceeded() + { + Pipe pipe = CreatePipe(); + pipe.Stream1.WriteTimeout = 300; + + // fill buffer without draining + pipe.Stream1.Write(new byte[64 * 1024], 0, 64 * 1024); + + Assert.ThrowsExactly(() => pipe.Stream1.Write(new byte[] { 1 }, 0, 1)); + } + + [TestMethod] + public void Read_ShouldFailWhenTimeoutExceeded() + { + Pipe pipe = CreatePipe(); + pipe.Stream2.ReadTimeout = 200; + + byte[] buffer = new byte[1]; + + Assert.ThrowsExactly(() => pipe.Stream2.Read(buffer, 0, 1)); + } + + // ------------------------------------------------------------ + // DISPOSAL CASCADE + // ------------------------------------------------------------ + + [TestMethod] + public void Dispose_ShouldStopOtherSideFromDeliveringData() + { + Pipe pipe = CreatePipe(); + pipe.Stream1.Dispose(); + + Assert.ThrowsExactly(() => pipe.Stream1.Write(new byte[] { 1 }, 0, 1)); + } + + public TestContext TestContext { get; set; } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/StreamExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/StreamExtensionsTests.cs new file mode 100644 index 0000000..d61fc47 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/StreamExtensionsTests.cs @@ -0,0 +1,188 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.IO; +using System.Threading.Tasks; +using TechnitiumLibrary.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.IO +{ + [TestClass] + public sealed class StreamExtensionsTests + { + private static MemoryStream StreamOf(params byte[] data) => + new MemoryStream(data, writable: true); + + // -------------------------------------------------------------------- + // ReadByteValue & WriteByteAsync + // -------------------------------------------------------------------- + + [TestMethod] + public void ReadByteValue_ShouldReturnFirstByte() + { + using var s = StreamOf("c"u8.ToArray()); + Assert.AreEqual(99, s.ReadByteValue()); + } + + [TestMethod] + public void ReadByteValue_ShouldThrow_WhenEmpty() + { + using var s = StreamOf(); + Assert.ThrowsExactly(() => s.ReadByteValue()); + } + + [TestMethod] + public async Task WriteByteAsync_ShouldWriteByte() + { + await using var s = new MemoryStream(); // expandable stream + + await s.WriteByteAsync(42, TestContext.CancellationToken); + + s.Position = 0; + + var value = await s.ReadByteValueAsync(TestContext.CancellationToken); + + Assert.AreEqual(42, value); + } + + // -------------------------------------------------------------------- + // ReadExactly + // -------------------------------------------------------------------- + + [TestMethod] + public void ReadExactly_ShouldReturnRequestedBytes() + { + using var s = StreamOf(1, 2, 3, 4); + var data = s.ReadExactly(3); + + CollectionAssert.AreEqual(new byte[] { 1, 2, 3 }, data); + } + + [TestMethod] + public void ReadExactly_ShouldThrow_WhenInsufficientData() + { + using var s = StreamOf(1, 2); + Assert.ThrowsExactly(() => s.ReadExactly(3)); + } + + [TestMethod] + public async Task ReadExactlyAsync_ShouldReturnRequestedBytes() + { + await using var s = StreamOf(10, 20, 30); + var result = await s.ReadExactlyAsync(2, TestContext.CancellationToken); + + CollectionAssert.AreEqual(new byte[] { 10, 20 }, result); + } + + [TestMethod] + public async Task ReadExactlyAsync_ShouldThrow_WhenStreamEnds() + { + await using var s = StreamOf(5); + await Assert.ThrowsExactlyAsync(() => s.ReadExactlyAsync(2, TestContext.CancellationToken)); + } + + // -------------------------------------------------------------------- + // Short string read/write + // -------------------------------------------------------------------- + + [TestMethod] + public void WriteShortString_ThenReadShortString_ShouldRoundtrip() + { + using var s = new MemoryStream(); // expandable stream + + s.WriteShortString("Hello"); + + s.Position = 0; + var str = s.ReadShortString(); + + Assert.AreEqual("Hello", str); + } + + [TestMethod] + public void WriteShortString_ShouldThrow_WhenLengthExceeds255() + { + string oversized = new string('A', 300); + + using var s = StreamOf(); + Assert.ThrowsExactly(() => s.WriteShortString(oversized)); + } + + [TestMethod] + public void ReadShortString_ShouldThrow_WhenLengthGreaterThanAvailableData() + { + using var s = StreamOf(2, 65); // length=2, only 1 byte remains + Assert.ThrowsExactly(() => s.ReadShortString()); + } + + [TestMethod] + public async Task WriteShortStringAsync_ShouldRoundtripWithUTF8() + { + await using var s = new MemoryStream(); // expandable + + await s.WriteShortStringAsync("test✓", TestContext.CancellationToken); + + s.Position = 0; + var parsed = await s.ReadShortStringAsync(TestContext.CancellationToken); + + Assert.AreEqual("test✓", parsed); + } + + // -------------------------------------------------------------------- + // CopyTo & CopyToAsync + // -------------------------------------------------------------------- + + [TestMethod] + public void CopyTo_ShouldCopyExactBytes() + { + using var src = StreamOf(1, 2, 3, 4); + using var dst = new MemoryStream(); // must be expandable here + + src.CopyTo(dst, bufferSize: 3, length: 3); + + CollectionAssert.AreEqual(new byte[] { 1, 2, 3 }, dst.ToArray()); + } + + [TestMethod] + public void CopyTo_ShouldFailWhenEOSIsReachedPrematurely() + { + using var src = StreamOf(1, 2); + using var dst = new MemoryStream(); // must allow writing + + Assert.ThrowsExactly(() => + src.CopyTo(dst, bufferSize: 4, length: 3)); + } + + [TestMethod] + public async Task CopyToAsync_ShouldCopyExactBytes() + { + await using var src = StreamOf("cba"u8.ToArray()); + await using var dst = new MemoryStream(); // expandable destination + + await src.CopyToAsync(dst, bufferSize: 10, length: 3, TestContext.CancellationToken); + + CollectionAssert.AreEqual("cba"u8.ToArray(), dst.ToArray()); + } + + [TestMethod] + public async Task CopyToAsync_ShouldFailWhenEOSReachedPrematurely() + { + await using var src = StreamOf("\t"u8.ToArray()); + await using var dst = new MemoryStream(); // expandable + + await Assert.ThrowsExactlyAsync(async () => + await src.CopyToAsync(dst, bufferSize: 8, length: 2, TestContext.CancellationToken)); + } + + [TestMethod] + public void CopyTo_ShouldReturnImmediately_WhenLengthIsZero() + { + using var src = StreamOf(1, 2, 3); + using var dst = StreamOf(); + + src.CopyTo(dst, bufferSize: 5, length: 0); + + Assert.IsEmpty(dst.ToArray()); + } + + public TestContext TestContext { get; set; } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/WriteBufferedStreamTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/WriteBufferedStreamTests.cs new file mode 100644 index 0000000..c244fe4 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.IO/WriteBufferedStreamTests.cs @@ -0,0 +1,279 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using TechnitiumLibrary.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.IO +{ + [TestClass] + public sealed class WriteBufferedStreamTests + { + private sealed class NonWritableStream : MemoryStream + { + public override bool CanWrite => false; + } + + private static MemoryStream CreateBaseStream(byte[]? initial = null) => + initial is null ? new MemoryStream() : new MemoryStream(initial); + + // ------------------------------------------------------ + // CONSTRUCTION / CAPABILITIES + // ------------------------------------------------------ + + [TestMethod] + public void Constructor_ShouldThrow_WhenBaseStreamNotWritable() + { + // GIVEN + using var baseStream = new NonWritableStream(); + + // WHEN-THEN + Assert.ThrowsExactly( + () => new WriteBufferedStream(baseStream)); + } + + [TestMethod] + public void Constructor_ShouldExposeCapabilitiesFromBaseStream() + { + // GIVEN + using var baseStream = CreateBaseStream(); + + // WHEN + using var buffered = new WriteBufferedStream(baseStream); + + // THEN + Assert.IsTrue(buffered.CanWrite); + Assert.AreEqual(baseStream.CanRead, buffered.CanRead); + Assert.AreEqual(baseStream.CanTimeout, buffered.CanTimeout); + Assert.IsFalse(buffered.CanSeek); + } + + // ------------------------------------------------------ + // BASIC WRITE & FLUSH (SYNC) + // ------------------------------------------------------ + + [TestMethod] + public void Write_ShouldBufferUntilFlushed() + { + // GIVEN + using var baseStream = CreateBaseStream(); + using var buffered = new WriteBufferedStream(baseStream, bufferSize: 8); + + var data = Encoding.ASCII.GetBytes("ABCD"); // 4 bytes + + // WHEN + buffered.Write(data, 0, data.Length); + + // THEN – nothing written yet to base + CollectionAssert.AreEqual(Array.Empty(), baseStream.ToArray()); + Assert.AreEqual(0L, baseStream.Length); + + // WHEN + buffered.Flush(); + + // THEN – data should now exist in base stream + CollectionAssert.AreEqual(data, baseStream.ToArray()); + } + + [TestMethod] + public void Write_ShouldFlushBufferWhenFull_AndKeepRemainderBuffered() + { + // GIVEN + using var baseStream = CreateBaseStream(); + using var buffered = new WriteBufferedStream(baseStream, bufferSize: 4); + + // 6 bytes, buffer 4 -> first 4 flushed, last 2 remain buffered after Flush + var data = Encoding.ASCII.GetBytes("ABCDEF"); + + // WHEN + buffered.Write(data, 0, data.Length); + + // buffer is full internally twice, so Flush() is invoked from Write + // After Write completes, we call Flush() to ensure remainder is written. + buffered.Flush(); + + // THEN + CollectionAssert.AreEqual(data, baseStream.ToArray()); + } + + // ------------------------------------------------------ + // BASIC WRITE & FLUSH (ASYNC) + // ------------------------------------------------------ + + [TestMethod] + public async Task WriteAsync_ShouldBufferAndFlushAsync() + { + // GIVEN + using var baseStream = CreateBaseStream(); + using var buffered = new WriteBufferedStream(baseStream, bufferSize: 8); + + var data = Encoding.UTF8.GetBytes("123456"); + + // WHEN + await buffered.WriteAsync(data, 0, data.Length, CancellationToken.None); + + // Still buffered + CollectionAssert.AreEqual(Array.Empty(), baseStream.ToArray()); + + await buffered.FlushAsync(CancellationToken.None); + + // THEN + CollectionAssert.AreEqual(data, baseStream.ToArray()); + } + + [TestMethod] + public async Task WriteAsync_MemoryOverload_ShouldRespectBuffering() + { + // GIVEN + using var baseStream = CreateBaseStream(); + using var buffered = new WriteBufferedStream(baseStream, bufferSize: 4); + + var data = Encoding.ASCII.GetBytes("WXYZ12"); // 6 bytes + + // WHEN + await buffered.WriteAsync(data.AsMemory(), CancellationToken.None); + await buffered.FlushAsync(CancellationToken.None); + + // THEN + CollectionAssert.AreEqual(data, baseStream.ToArray()); + } + + // ------------------------------------------------------ + // READ DELEGATION + // ------------------------------------------------------ + + [TestMethod] + public void Read_ShouldDelegateToBaseStream() + { + // GIVEN + var initial = Encoding.ASCII.GetBytes("HELLO"); + using var baseStream = CreateBaseStream(initial); + using var buffered = new WriteBufferedStream(baseStream); + + // WHEN + var buffer = new byte[5]; + baseStream.Position = 0; // ensure we read from start + var read = buffered.Read(buffer, 0, buffer.Length); + + // THEN + Assert.AreEqual(5, read); + CollectionAssert.AreEqual(initial, buffer); + } + + // ------------------------------------------------------ + // SEEK / LENGTH / POSITION BEHAVIOR + // ------------------------------------------------------ + + [TestMethod] + public void Position_Get_ShouldMatchBaseStreamPosition() + { + // GIVEN + using var baseStream = CreateBaseStream(new byte[10]); + baseStream.Position = 4; + using var buffered = new WriteBufferedStream(baseStream); + + // WHEN + var position = buffered.Position; + + // THEN + Assert.AreEqual(4L, position); + } + + [TestMethod] + public void Position_Set_ShouldThrow_NotSupported() + { + // GIVEN + using var baseStream = CreateBaseStream(); + using var buffered = new WriteBufferedStream(baseStream); + + // WHEN-THEN + Assert.ThrowsExactly(() => + buffered.Position = 1); + } + + [TestMethod] + public void Seek_ShouldThrow_NotSupported() + { + // GIVEN + using var baseStream = CreateBaseStream(); + using var buffered = new WriteBufferedStream(baseStream); + + // WHEN-THEN + Assert.ThrowsExactly(() => + buffered.Seek(0, SeekOrigin.Begin)); + } + + [TestMethod] + public void SetLength_ShouldThrow_NotSupported() + { + // GIVEN + using var baseStream = CreateBaseStream(); + using var buffered = new WriteBufferedStream(baseStream); + + // WHEN-THEN + Assert.ThrowsExactly(() => + buffered.SetLength(10)); + } + + // ------------------------------------------------------ + // DISPOSAL & OWNERSHIP + // ------------------------------------------------------ + + [TestMethod] + public void Dispose_ShouldDisposeUnderlyingStream() + { + // GIVEN + var baseStream = CreateBaseStream(); + var buffered = new WriteBufferedStream(baseStream); + + // WHEN + buffered.Dispose(); + + // THEN – base stream also disposed + Assert.ThrowsExactly(() => + baseStream.WriteByte(1)); + } + + [TestMethod] + public void Write_ShouldThrow_WhenDisposed() + { + // GIVEN + using var baseStream = CreateBaseStream(); + var buffered = new WriteBufferedStream(baseStream); + buffered.Dispose(); + + // WHEN-THEN + Assert.ThrowsExactly(() => + buffered.Write(new byte[] { 1 }, 0, 1)); + } + + [TestMethod] + public async Task WriteAsync_ShouldThrow_WhenDisposed() + { + // GIVEN + using var baseStream = CreateBaseStream(); + var buffered = new WriteBufferedStream(baseStream); + buffered.Dispose(); + + // WHEN-THEN + await Assert.ThrowsExactlyAsync(() => + buffered.WriteAsync(new byte[] { 1 }, 0, 1, CancellationToken.None)); + } + + [TestMethod] + public async Task FlushAsync_ShouldNotFlush_WhenNothingBuffered() + { + // GIVEN + using var baseStream = CreateBaseStream(); + using var buffered = new WriteBufferedStream(baseStream); + + // WHEN + await buffered.FlushAsync(CancellationToken.None); + + // THEN – nothing written + CollectionAssert.AreEqual(Array.Empty(), baseStream.ToArray()); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.Firewall/WindowsFirewallTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.Firewall/WindowsFirewallTests.cs new file mode 100644 index 0000000..bba80d9 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net.Firewall/WindowsFirewallTests.cs @@ -0,0 +1,55 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using TechnitiumLibrary.Net.Firewall; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net.Firewall +{ + [TestClass] + public sealed class WindowsFirewallPublicTests + { + [TestMethod] + public void AddPort_ShouldThrow_WhenUnsupportedProtocol() + { + // Protocol ICMPv4 cannot be added using AddPort + Assert.ThrowsExactly(() => WindowsFirewall.AddPort("bad", Protocol.ICMPv4, port: 55, enable: true)); + } + + [TestMethod] + public void RemovePort_ShouldThrow_WhenUnsupportedProtocol() + { + // RemovePort validates only TCP, UDP, ANY + Assert.ThrowsExactly(() => WindowsFirewall.RemovePort(Protocol.IGMP, 123)); + } + + [TestMethod] + public void PortExists_ShouldThrow_WhenUnsupportedProtocol() + { + Assert.ThrowsExactly(() => WindowsFirewall.PortExists(Protocol.IGMP, 44)); + } + + [TestMethod] + public void RuleExistsVista_ShouldReturnDoesNotExist_WhenInputsClearlyNotMatchingAnything() + { + // Since firewall is not guaranteed to have this rule, + // safest expected response is DoesNotExists. + var result = WindowsFirewall.RuleExistsVista( + name: "__Definitely_Not_A_Real_Rule__", + applicationPath: "__Fake__"); + + Assert.AreEqual(RuleStatus.DoesNotExists, result); + } + + [TestMethod] + public void ApplicationExists_ShouldReturnDoesNotExist_WhenApplicationIsNotRegistered() + { + // Public observable guarantee: + // if the system has no such application entry → DoesNotExists + + const string fakePath = "C:\\DefinitelyNotExisting\\app.exe"; + + var status = WindowsFirewall.ApplicationExists(fakePath); + + Assert.AreEqual(RuleStatus.DoesNotExists, status); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/DomainEndPointTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/DomainEndPointTests.cs new file mode 100644 index 0000000..6da13d3 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/DomainEndPointTests.cs @@ -0,0 +1,317 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Net.Sockets; +using System.Text; +using TechnitiumLibrary.Net; +using TechnitiumLibrary.Net.Dns; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net +{ + [TestClass] + public sealed class DomainEndPointTests + { + // ================================================================ + // CONSTRUCTOR – SUCCESS CASES + // ================================================================ + + [TestMethod] + public void Constructor_ShouldAcceptAsciiDomain_AndStorePort() + { + var ep = new DomainEndPoint("example.com", 853); + + Assert.AreEqual("example.com", ep.Address, + "Constructor must preserve ASCII domain without alteration."); + Assert.AreEqual(853, ep.Port, + "Constructor must store provided port value exactly."); + Assert.AreEqual(AddressFamily.Unspecified, ep.AddressFamily, + "Domain endpoints must remain AddressFamily.Unspecified for defensive correctness."); + } + + [TestMethod] + public void Constructor_ShouldNormalizeUnicodeToAscii() + { + var ep = new DomainEndPoint("münich.de", 443); + + Assert.AreEqual("xn--mnich-kva.de", ep.Address, + "Constructor must normalize Unicode domain into IDN ASCII equivalent."); + Assert.AreEqual(443, ep.Port, + "Port must remain exactly as provided."); + } + + + // ================================================================ + // CONSTRUCTOR – FAILURE CASES + // ================================================================ + + [TestMethod] + public void Constructor_ShouldFailFast_WhenAddressIsNull() + { + var ex = Assert.ThrowsExactly( + () => _ = new DomainEndPoint(null!, 53), + "Null address must be rejected to prevent partially invalid instance."); + + Assert.AreEqual("address", ex.ParamName, + "Thrown exception must identify the faulty parameter."); + } + + [TestMethod] + public void Constructor_ShouldRejectIPv4Literal() + { + Assert.ThrowsExactly( + () => _ = new DomainEndPoint("192.168.1.1", 80), + "Constructor must reject IP literals to preserve domain-only invariant."); + } + + [TestMethod] + public void Constructor_ShouldRejectObviouslyMalformedDomain() + { + var ex = Assert.ThrowsExactly( + () => _ = new DomainEndPoint("exa mple.com", 853), + "Constructor must reject syntactically invalid domain by failing fast through validation-layer exception."); + + Assert.Contains("exa mple.com", ex.Message, "Thrown validation exception must include original input for caller diagnostic correctness."); + } + + + // ================================================================ + // TRY PARSE – SUCCESS CASES + // ================================================================ + + [TestMethod] + public void TryParse_ShouldParseDomainWithoutPort_DefaultPortZero() + { + var ok = DomainEndPoint.TryParse("example.com", out var ep); + + Assert.IsTrue(ok, "TryParse must succeed for valid domain without port."); + Assert.IsNotNull(ep, "Successful TryParse must produce a concrete instance."); + Assert.AreEqual("example.com", ep.Address, + "Domain segment must remain unchanged."); + Assert.AreEqual(0, ep.Port, + "No explicit port must result in Port=0."); + } + + [TestMethod] + public void TryParse_ShouldParseDomainWithPort() + { + var ok = DomainEndPoint.TryParse("example.com:445", out var ep); + + Assert.IsTrue(ok, + "TryParse must succeed for expected domain:port format."); + Assert.AreEqual("example.com", ep!.Address); + Assert.AreEqual(445, ep.Port); + } + + [TestMethod] + public void TryParse_ShouldNormalizeUnicodeDomain() + { + var ok = DomainEndPoint.TryParse("münich.de:80", out var ep); + + Assert.IsTrue(ok, "Valid Unicode domain must be accepted."); + Assert.AreEqual("xn--mnich-kva.de", ep!.Address, + "Unicode must normalize predictably to ASCII."); + Assert.AreEqual(80, ep.Port, + "Port must reflect provided integer value."); + } + + [TestMethod] + public void TryParse_ShouldRoundtripSuccessfully() + { + const string original = "example.com:853"; + + Assert.IsTrue(DomainEndPoint.TryParse(original, out var ep1), + "TryParse must succeed on valid input."); + + var serialized = ep1!.ToString(); + Assert.IsTrue(DomainEndPoint.TryParse(serialized, out var ep2), + "Re-parsing output must succeed."); + + Assert.AreEqual(ep1.Address, ep2!.Address, + "Roundtrip must preserve domain identity exactly."); + Assert.AreEqual(ep1.Port, ep2.Port, + "Roundtrip must preserve port identity exactly."); + } + + + // ================================================================ + // TRY PARSE – FAILURE CASES + // ================================================================ + + [TestMethod] + public void TryParse_ShouldFail_WhenInputIsNull() + { + var ok = DomainEndPoint.TryParse(null, out var ep); + + Assert.IsFalse(ok, "Null value cannot represent valid domain endpoint."); + Assert.IsNull(ep, "Endpoint must remain null when parsing fails."); + } + + [TestMethod] + public void TryParse_ShouldFail_WhenEmptyString() + { + var ok = DomainEndPoint.TryParse("", out var ep); + + Assert.IsFalse(ok, "Empty string cannot represent valid domain endpoint."); + Assert.IsNull(ep, "Endpoint must remain null when parsing fails."); + } + + [TestMethod] + public void TryParse_ShouldFail_WhenWhitespaceOnly() + { + var ok = DomainEndPoint.TryParse(" ", out var ep); + + Assert.IsFalse(ok, "Whitespace-only input cannot represent valid domain endpoint."); + Assert.IsNull(ep, "Result object must remain null on failure."); + } + + [TestMethod] + public void TryParse_ShouldFail_WhenTooManyColons() + { + var ok = DomainEndPoint.TryParse("a:b:c", out var ep); + + Assert.IsFalse(ok, "Multiple separators violate predictable domain:port format."); + Assert.IsNull(ep, "Endpoint must remain null to avoid partially valid identity."); + } + + [TestMethod] + public void TryParse_ShouldFail_WhenDomainIsIPAddress() + { + var ok = DomainEndPoint.TryParse("127.0.0.1:81", out var ep); + + Assert.IsFalse(ok, "IP literal parsing must be rejected consistently."); + Assert.IsNull(ep, "Null endpoint is required defensive failure output."); + } + + [TestMethod] + public void TryParse_ShouldFail_WhenNonNumericPort() + { + var ok = DomainEndPoint.TryParse("example.com:abc", out var ep); + + Assert.IsFalse(ok, "Port must parse strictly as numeric."); + Assert.IsNull(ep, "Failure scenario must not yield partially created endpoint."); + } + + [TestMethod] + public void TryParse_ShouldFail_WhenPortOutOfRange() + { + var ok = DomainEndPoint.TryParse("example.com:70000", out var ep); + + Assert.IsFalse(ok, "Ports exceeding UInt16 range cannot be treated as valid."); + Assert.IsNull(ep, "No endpoint must be generated."); + } + + [TestMethod] + public void TryParse_ShouldFail_WhenDomainContainsSpaces() + { + var ok = DomainEndPoint.TryParse("exa mple.com:53", out var ep); + + Assert.IsFalse(ok, "Invalid domain format must not succeed."); + Assert.IsNull(ep, "Endpoint must remain null upon failure."); + } + + + // ================================================================ + // ADDRESS BYTES + // ================================================================ + + [TestMethod] + public void GetAddressBytes_MustReturnLengthPrefixedAsciiBytes() + { + var ep = new DomainEndPoint("example.com", 80); + var result = ep.GetAddressBytes(); + + var ascii = Encoding.ASCII.GetBytes("example.com"); + + Assert.AreEqual(ascii.Length, result[0], + "Length prefix must exactly match ASCII length of the address."); + for (int i = 0; i < ascii.Length; i++) + { + Assert.AreEqual(ascii[i], result[i + 1], + $"Byte index {i} must reflect ASCII domain payload."); + } + } + + [TestMethod] + public void GetAddressBytes_MustReturnIndependentBuffers() + { + var ep = new DomainEndPoint("example.com", 80); + + var a = ep.GetAddressBytes(); + a[1] ^= 0xFF; + + var b = ep.GetAddressBytes(); + + Assert.AreNotEqual(a[1], b[1], + "Returned byte arrays must not expose internal mutable buffers."); + } + + + // ================================================================ + // EQUALITY & HASH + // ================================================================ + + [TestMethod] + public void Equals_MustBeCaseInsensitiveForDomain_AndStrictOnPort() + { + var ep1 = new DomainEndPoint("Example.com", 443); + var ep2 = new DomainEndPoint("example.com", 443); + var ep3 = new DomainEndPoint("example.com", 853); + + Assert.IsTrue(ep1.Equals(ep2), + "Domain equality must ignore case differences."); + Assert.IsFalse(ep1.Equals(ep3), + "Different ports must break equality even when domain matches."); + } + + [TestMethod] + public void GetHashCode_MustBeStableAcrossRepeatedCalls() + { + var ep = new DomainEndPoint("example.com", 443); + + var h1 = ep.GetHashCode(); + var h2 = ep.GetHashCode(); + + Assert.AreEqual(h1, h2, + "Hash code must remain stable to support predictable dictionary usage."); + } + + [TestMethod] + public void Equals_MustReturnFalse_ForDifferentTypeAndNull() + { + var ep = new DomainEndPoint("example.com", 80); + + Assert.IsFalse(ep.Equals(null), + "Comparing against null must never produce equality."); + Assert.IsFalse(ep.Equals("example.com:80"), + "Comparing against non-endpoint type must not succeed."); + } + + + // ================================================================ + // PROPERTY SETTERS + // ================================================================ + + [TestMethod] + public void Address_Setter_MustNotCorruptUnrelatedState() + { + var ep = new DomainEndPoint("example.com", 53); + + ep.Address = "192.168.9.10"; + + Assert.AreEqual("192.168.9.10", ep.Address, + "Setter does not re-validate by design; caller assumes responsibility."); + Assert.AreEqual(53, ep.Port, + "Setter mutation must not affect unrelated fields."); + } + + [TestMethod] + public void Port_Setter_MustAllowCallerProvidedValueAsIs() + { + var ep = new DomainEndPoint("example.com", 53); + + ep.Port = -1; + + Assert.AreEqual(-1, ep.Port, + "Setter must store raw caller intent; constraints belong outside endpoint abstraction."); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/EndPointExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/EndPointExtensionsTests.cs new file mode 100644 index 0000000..ef4e5d8 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/EndPointExtensionsTests.cs @@ -0,0 +1,252 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.IO; +using System.Net; +using System.Net.Sockets; +using System.Threading.Tasks; +using TechnitiumLibrary.Net; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net +{ + [TestClass] + public sealed class EndPointExtensionsTests + { + [TestMethod] + public void WriteRead_RoundTrip_IPv4() + { + var ep = new IPEndPoint(IPAddress.Parse("192.168.10.25"), 853); + + using var ms = new MemoryStream(); + using var bw = new BinaryWriter(ms); + + ep.WriteTo(bw); + ms.Position = 0; + + using var br = new BinaryReader(ms); + EndPoint reloaded = EndPointExtensions.ReadFrom(br); + + Assert.AreEqual(ep.Address.ToString(), reloaded.GetAddress(), + "Round-trip must preserve IPv4 address."); + Assert.AreEqual(ep.Port, reloaded.GetPort(), + "Round-trip must preserve port."); + } + + [TestMethod] + public void WriteRead_RoundTrip_IPv6() + { + var ep = new IPEndPoint(IPAddress.IPv6Loopback, 853); + + using var ms = new MemoryStream(); + using var bw = new BinaryWriter(ms); + + ep.WriteTo(bw); + ms.Position = 0; + + using var br = new BinaryReader(ms); + EndPoint reloaded = EndPointExtensions.ReadFrom(br); + + Assert.AreEqual("::1", reloaded.GetAddress(), + "Round-trip must preserve IPv6 loopback."); + Assert.AreEqual(853, reloaded.GetPort(), + "Round-trip must preserve port."); + } + + [TestMethod] + public void WriteRead_RoundTrip_Domain() + { + var dep = new DomainEndPoint("example.org", 853); + + using var ms = new MemoryStream(); + using var bw = new BinaryWriter(ms); + + dep.WriteTo(bw); + ms.Position = 0; + + using var br = new BinaryReader(ms); + EndPoint reloaded = EndPointExtensions.ReadFrom(br); + + Assert.AreEqual("example.org", reloaded.GetAddress(), + "Domain must survive round-trip serialization."); + Assert.AreEqual(853, reloaded.GetPort(), + "Port must survive round-trip serialization."); + } + + [TestMethod] + public void ReadFrom_ShouldFail_OnUnsupportedDiscriminator() + { + using var ms = new MemoryStream(); + using var bw = new BinaryWriter(ms); + + bw.Write((byte)99); // invalid discriminator + ms.Position = 0; + + using var br = new BinaryReader(ms); + Assert.ThrowsExactly( + () => _ = EndPointExtensions.ReadFrom(br), + "Unsupported prefix must trigger deterministic failure."); + } + + [TestMethod] + public void GetAddress_ShouldReturn_IPString() + { + var ep = new IPEndPoint(IPAddress.Parse("1.2.3.4"), 1234); + Assert.AreEqual("1.2.3.4", ep.GetAddress(), + "Address must be returned as textual IPv4."); + } + + [TestMethod] + public void GetAddress_ShouldReturn_DomainString() + { + var ep = new DomainEndPoint("dns.google", 53); + Assert.AreEqual("dns.google", ep.GetAddress(), + "Domain must be returned as raw host label."); + } + + [TestMethod] + public void GetPort_ShouldReturn_Port() + { + var ep = new IPEndPoint(IPAddress.Loopback, 1111); + Assert.AreEqual(1111, ep.GetPort(), "Port must be returned unchanged."); + } + + [TestMethod] + public void SetPort_ShouldMutate_IPPort() + { + var ep = new IPEndPoint(IPAddress.Loopback, 53); + ep.SetPort(443); + + Assert.AreEqual(443, ep.Port, "Mutated port must be observable."); + } + + [TestMethod] + public async Task GetIPEndPointAsync_ShouldReturn_IP_WhenAlreadyIPEndPoint() + { + var ep = new IPEndPoint(IPAddress.Parse("127.0.0.1"), 9000); + + var result = await ep.GetIPEndPointAsync(cancellationToken: TestContext.CancellationToken); + + Assert.AreEqual(ep.Address, result.Address, + "Resolved IP must match source."); + Assert.AreEqual(ep.Port, result.Port, + "Resolved port must match source."); + } + + [TestMethod] + public async Task GetIPEndPointAsync_ShouldResolve_Localhost_Predictably() + { + var dep = new DomainEndPoint("localhost", 443); + + var resolved = await dep.GetIPEndPointAsync(AddressFamily.InterNetwork, cancellationToken: TestContext.CancellationToken); + + Assert.AreEqual(443, resolved.Port, "Resolved port must match declared port."); + Assert.AreEqual(AddressFamily.InterNetwork, resolved.Address.AddressFamily, + "Requested AF must be honored when at least one matching address exists."); + } + + [TestMethod] + public async Task GetIPEndPointAsync_ShouldFail_WhenDNSReturnsEmpty() + { + var dep = new DomainEndPoint("test-invalid-unresolvable-domain.local", 5000); + + await Assert.ThrowsExactlyAsync( + async () => await dep.GetIPEndPointAsync(cancellationToken: TestContext.CancellationToken), + "Unresolvable name must trigger HostNotFound."); + } + + [TestMethod] + public async Task GetIPEndPointAsync_ShouldFallback_WhenRequestedFamilyUnsupported() + { + var dep = new DomainEndPoint("localhost", 853); + + var ep = await dep.GetIPEndPointAsync(AddressFamily.AppleTalk, cancellationToken: TestContext.CancellationToken); + + Assert.IsNotNull(ep); + Assert.AreEqual(853, ep.Port, "Port must be preserved."); + Assert.IsInstanceOfType(ep, typeof(IPEndPoint), "Returned endpoint must still be resolved."); + } + + [TestMethod] + public void GetEndPoint_ShouldReturn_IPEndpoint_OnLiteralIP() + { + EndPoint ep = EndPointExtensions.GetEndPoint("10.20.30.40", 8080); + + Assert.IsInstanceOfType(ep, typeof(IPEndPoint), + "Literal IP input must produce IPEndPoint."); + } + + [TestMethod] + public void GetEndPoint_ShouldReturn_DomainEndPoint_OnHostName() + { + EndPoint ep = EndPointExtensions.GetEndPoint("dns.google", 53); + + Assert.IsInstanceOfType(ep, typeof(DomainEndPoint), + "Non-IP literal must produce domain endpoint."); + } + + [TestMethod] + public void TryParse_ShouldReturnTrue_ForIPEndPointSyntax() + { + Assert.IsTrue(EndPointExtensions.TryParse("5.6.7.8:22", out var ep), + "Valid IP must be parsed."); + Assert.IsInstanceOfType(ep, typeof(IPEndPoint)); + } + + [TestMethod] + public void TryParse_ShouldReturnTrue_ForDomainSyntax() + { + Assert.IsTrue(EndPointExtensions.TryParse("example.com:25", out var ep), + "Valid domain:port must be parsed."); + Assert.IsInstanceOfType(ep, typeof(DomainEndPoint)); + } + + [TestMethod] + public void TryParse_ShouldFail_WhenMissingPort() + { + Assert.IsFalse(EndPointExtensions.TryParse("example.com", out var ep), + "Missing port must not parse successfully."); + Assert.IsNull(ep, "Return must be null on parse failure."); + } + + [TestMethod] + public void IsEquals_ShouldCompare_IPCorrectly() + { + var a = new IPEndPoint(IPAddress.Parse("1.1.1.1"), 853); + var b = new IPEndPoint(IPAddress.Parse("1.1.1.1"), 853); + + Assert.IsTrue(a.IsEquals(b), + "IPEndPoint equality must fully honor IP + port."); + } + + [TestMethod] + public void IsEquals_ShouldCompare_DomainCorrectly() + { + var a = new DomainEndPoint("example.org", 443); + var b = new DomainEndPoint("example.org", 443); + + Assert.IsTrue(a.IsEquals(b), + "Domain endpoints must compare by semantic equality."); + } + + [TestMethod] + public void IsEquals_MustReturnFalse_OnDifferentAddresses() + { + var a = new DomainEndPoint("example.org", 443); + var b = new DomainEndPoint("example.net", 443); + + Assert.IsFalse(a.IsEquals(b), + "Different hostnames must not compare equal."); + } + + [TestMethod] + public void IsEquals_MustReturnFalse_OnDifferentPorts() + { + var a = new IPEndPoint(IPAddress.Parse("8.8.8.8"), 53); + var b = new IPEndPoint(IPAddress.Parse("8.8.8.8"), 853); + + Assert.IsFalse(a.IsEquals(b), + "Same address but different port must not compare equal."); + } + + public TestContext TestContext { get; set; } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/IPAddressExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/IPAddressExtensionsTests.cs new file mode 100644 index 0000000..aa06368 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/IPAddressExtensionsTests.cs @@ -0,0 +1,445 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.IO; +using System.Net; +using TechnitiumLibrary.Net; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net +{ + [TestClass] + public sealed class IPAddressExtensionsTests + { + private static MemoryStream NewStream(byte[]? initial = null) => + initial is null ? new MemoryStream() : new MemoryStream(initial, writable: true); + + // ------------------------------------------------------ + // WRITE & READ (BINARY FORMAT) + // ------------------------------------------------------ + + [TestMethod] + public void WriteTo_ThenReadFrom_ShouldRoundtrip_IPv4() + { + // GIVEN + IPAddress ip = IPAddress.Parse("1.2.3.4"); + using MemoryStream ms = NewStream(); + + // WHEN + ip.WriteTo(ms); + ms.Position = 0; + IPAddress read = IPAddressExtensions.ReadFrom(ms); + + // THEN + Assert.AreEqual(ip, read, "WriteTo/ReadFrom must preserve IPv4 address bits exactly."); + Assert.AreEqual(ms.Length, ms.Position, + "ReadFrom must consume exactly one encoded address and no more bytes."); + } + + [TestMethod] + public void WriteTo_ThenReadFrom_ShouldRoundtrip_IPv6() + { + // GIVEN + IPAddress ip = IPAddress.Parse("2001:db8::1"); + using MemoryStream ms = NewStream(); + + // WHEN + ip.WriteTo(ms); + ms.Position = 0; + IPAddress read = IPAddressExtensions.ReadFrom(ms); + + // THEN + Assert.AreEqual(ip, read, "WriteTo/ReadFrom must preserve IPv6 address bits exactly."); + Assert.AreEqual(ms.Length, ms.Position, + "ReadFrom must consume exactly one encoded IPv6 address and no extra bytes."); + } + + [TestMethod] + public void WriteTo_WithBinaryWriter_ShouldProduceSameFormat() + { + // GIVEN + IPAddress ip = IPAddress.Parse("10.20.30.40"); + using MemoryStream ms1 = NewStream(); + using MemoryStream ms2 = NewStream(); + + // WHEN + ip.WriteTo(ms1); // direct Stream overload + + using (BinaryWriter writer = new BinaryWriter(ms2, System.Text.Encoding.UTF8, leaveOpen: true)) + { + ip.WriteTo(writer); + } + + // THEN + CollectionAssert.AreEqual(ms1.ToArray(), ms2.ToArray(), + "WriteTo(BinaryWriter) must delegate to identical wire format as WriteTo(Stream)."); + } + + [TestMethod] + public void ReadFrom_ShouldThrowEndOfStream_WhenNoFamilyMarkerAvailable() + { + // GIVEN + using MemoryStream ms = NewStream(Array.Empty()); + long startPos = ms.Position; + + // WHEN - THEN + Assert.ThrowsExactly( + () => IPAddressExtensions.ReadFrom(ms), + "ReadFrom must fail fast when stream ends before family marker."); + + Assert.AreEqual(startPos, ms.Position, + "On EOS, ReadFrom must not advance stream position."); + } + + [TestMethod] + public void ReadFrom_ShouldThrowNotSupported_WhenFamilyMarkerUnknown() + { + // GIVEN: marker 3 (unsupported) + one extra byte (must remain unread) + using MemoryStream ms = NewStream(new byte[] { 3, 0xFF }); + + // WHEN + Assert.ThrowsExactly( + () => IPAddressExtensions.ReadFrom(ms), + "ReadFrom must reject unsupported address family markers deterministically."); + + // THEN + Assert.AreEqual(1L, ms.Position, + "On unsupported family marker, ReadFrom must consume only the marker byte and leave payload intact."); + Assert.AreEqual(2L, ms.Length); + } + + // ------------------------------------------------------ + // IPv4 <-> NUMBER CONVERSION + // ------------------------------------------------------ + + [TestMethod] + public void ConvertIpToNumber_ThenBack_ShouldRoundtrip_IPv4() + { + // GIVEN + IPAddress ip = IPAddress.Parse("1.2.3.4"); + + // WHEN + uint number = ip.ConvertIpToNumber(); + IPAddress roundtrip = IPAddressExtensions.ConvertNumberToIp(number); + + // THEN + Assert.AreEqual("1.2.3.4", roundtrip.ToString(), + "ConvertNumberToIp(ConvertIpToNumber(ip)) must yield the original IPv4 address."); + } + + [TestMethod] + public void ConvertIpToNumber_ShouldThrow_WhenAddressIsIPv6() + { + // GIVEN + IPAddress ip = IPAddress.Parse("::1"); + + // WHEN - THEN + Assert.ThrowsExactly( + () => ip.ConvertIpToNumber(), + "ConvertIpToNumber must reject non-IPv4 addresses with ArgumentException."); + } + + // ------------------------------------------------------ + // SUBNET MASK HELPERS + // ------------------------------------------------------ + + [TestMethod] + public void GetSubnetMask_ShouldReturnCorrectMasks_ForBoundaryPrefixLengths() + { + // WHEN + IPAddress mask0 = IPAddressExtensions.GetSubnetMask(0); + IPAddress mask24 = IPAddressExtensions.GetSubnetMask(24); + IPAddress mask32 = IPAddressExtensions.GetSubnetMask(32); + + // THEN + Assert.AreEqual("0.0.0.0", mask0.ToString(), + "Prefix length 0 must map to all-zero IPv4 mask."); + Assert.AreEqual("255.255.255.0", mask24.ToString(), + "Prefix length 24 must map to 255.255.255.0."); + Assert.AreEqual("255.255.255.255", mask32.ToString(), + "Prefix length 32 must map to 255.255.255.255."); + } + + [TestMethod] + public void GetSubnetMask_ShouldThrow_WhenPrefixExceedsIPv4Width() + { + Assert.ThrowsExactly( + () => IPAddressExtensions.GetSubnetMask(33), + "GetSubnetMask must reject prefix lengths greater than 32."); + } + + [TestMethod] + public void GetSubnetMaskWidth_ShouldReturnCorrectWidth_ForValidMasks() + { + // GIVEN + IPAddress mask0 = IPAddress.Parse("0.0.0.0"); + IPAddress mask8 = IPAddress.Parse("255.0.0.0"); + IPAddress mask24 = IPAddress.Parse("255.255.255.0"); + + // WHEN + int width0 = mask0.GetSubnetMaskWidth(); + int width8 = mask8.GetSubnetMaskWidth(); + int width24 = mask24.GetSubnetMaskWidth(); + + // THEN + Assert.AreEqual(0, width0, "Mask 0.0.0.0 must have width 0."); + Assert.AreEqual(8, width8, "Mask 255.0.0.0 must have width 8."); + Assert.AreEqual(24, width24, "Mask 255.255.255.0 must have width 24."); + } + + [TestMethod] + public void GetSubnetMaskWidth_ShouldThrow_WhenMaskIsNotIPv4() + { + // GIVEN + IPAddress ipv6Mask = IPAddress.Parse("ffff::"); + + // WHEN - THEN + Assert.ThrowsExactly( + () => ipv6Mask.GetSubnetMaskWidth(), + "GetSubnetMaskWidth must reject non-IPv4 subnet masks."); + } + + // ------------------------------------------------------ + // GET NETWORK ADDRESS + // ------------------------------------------------------ + + [TestMethod] + public void GetNetworkAddress_ShouldZeroOutHostBits_ForIPv4() + { + // GIVEN + IPAddress ip = IPAddress.Parse("192.168.10.123"); + + // WHEN + IPAddress network24 = ip.GetNetworkAddress(24); + IPAddress network16 = ip.GetNetworkAddress(16); + IPAddress network0 = ip.GetNetworkAddress(0); + + // THEN + Assert.AreEqual("192.168.10.0", network24.ToString(), + "Prefix 24 must zero out last octet."); + Assert.AreEqual("192.168.0.0", network16.ToString(), + "Prefix 16 must zero out last two octets."); + Assert.AreEqual("0.0.0.0", network0.ToString(), + "Prefix 0 must zero out all IPv4 bits."); + } + + [TestMethod] + public void GetNetworkAddress_ShouldReturnSameAddress_ForFullPrefixLength() + { + // GIVEN + IPAddress ip4 = IPAddress.Parse("10.0.0.42"); + IPAddress ip6 = IPAddress.Parse("2001:db8::dead:beef"); + + // WHEN + IPAddress net4 = ip4.GetNetworkAddress(32); + IPAddress net6 = ip6.GetNetworkAddress(128); + + // THEN + Assert.AreEqual(ip4, net4, + "IPv4 prefix 32 must leave the address unchanged."); + Assert.AreEqual(ip6, net6, + "IPv6 prefix 128 must leave the address unchanged."); + } + + [TestMethod] + public void GetNetworkAddress_ShouldThrow_WhenPrefixTooLargeForFamily() + { + // GIVEN + IPAddress ip4 = IPAddress.Parse("192.168.1.1"); + IPAddress ip6 = IPAddress.Parse("2001:db8::1"); + + // WHEN - THEN + Assert.ThrowsExactly( + () => ip4.GetNetworkAddress(33), + "IPv4 network prefix > 32 must be rejected."); + Assert.ThrowsExactly( + () => ip6.GetNetworkAddress(129), + "IPv6 network prefix > 128 must be rejected."); + } + + // ------------------------------------------------------ + // REVERSE DOMAIN GENERATION + // ------------------------------------------------------ + + [TestMethod] + public void GetReverseDomain_ShouldReturnCorrectIPv4PtrName() + { + // GIVEN + IPAddress ip = IPAddress.Parse("192.168.10.1"); + + // WHEN + string ptr = ip.GetReverseDomain(); + + // THEN + Assert.AreEqual("1.10.168.192.in-addr.arpa", ptr, + "IPv4 reverse domain must list octets in reverse order followed by in-addr.arpa."); + } + + [TestMethod] + public void GetReverseDomain_ThenParseReverseDomain_ShouldRoundtrip_IPv4() + { + // GIVEN + IPAddress ip = IPAddress.Parse("10.20.30.40"); + + // WHEN + string ptr = ip.GetReverseDomain(); + IPAddress parsed = IPAddressExtensions.ParseReverseDomain(ptr); + + // THEN + Assert.AreEqual(ip, parsed, + "ParseReverseDomain(GetReverseDomain(ip)) must roundtrip IPv4 address exactly."); + } + + [TestMethod] + public void GetReverseDomain_ThenParseReverseDomain_ShouldRoundtrip_IPv6() + { + // GIVEN + IPAddress ip = IPAddress.Parse("2001:db8::8b3b:3eb"); + + // WHEN + string ptr = ip.GetReverseDomain(); + IPAddress parsed = IPAddressExtensions.ParseReverseDomain(ptr); + + // THEN + Assert.AreEqual(ip, parsed, + "ParseReverseDomain(GetReverseDomain(ip)) must roundtrip IPv6 address exactly, including all nibbles."); + } + + // ------------------------------------------------------ + // TRY PARSE REVERSE DOMAIN – FAILURE HYGIENE + // ------------------------------------------------------ + + [TestMethod] + public void TryParseReverseDomain_ShouldReturnFalseAndNull_ForUnknownSuffix() + { + // GIVEN + IPAddress original = IPAddress.Loopback; // must be overwritten on failure + + // WHEN + bool ok = IPAddressExtensions.TryParseReverseDomain("example.com", out IPAddress? parsed); + + // THEN + Assert.IsFalse(ok, "TryParseReverseDomain must return false for non-PTR domains."); + Assert.IsNull(parsed, + "On failure, TryParseReverseDomain must set out address to null to avoid stale references."); + } + + [TestMethod] + public void TryParseReverseDomain_ShouldReturnFalseAndNull_WhenIPv4LabelsAreNotNumeric() + { + // GIVEN + const string invalidPtr = "x.10.168.192.in-addr.arpa"; + + // WHEN + bool ok = IPAddressExtensions.TryParseReverseDomain(invalidPtr, out IPAddress? parsed); + + // THEN + Assert.IsFalse(ok, "Non-numeric IPv4 labels must cause TryParseReverseDomain to fail cleanly."); + Assert.IsNull(parsed, + "On invalid IPv4 PTR, out address must be null to avoid partial parsing."); + } + + [TestMethod] + public void TryParseReverseDomain_ShouldRejectShortIPv4Ptr() + { + const string ptr = "3.2.1.in-addr.arpa"; + + bool ok = IPAddressExtensions.TryParseReverseDomain(ptr, out IPAddress? parsed); + + Assert.IsFalse(ok, "Short IPv4 PTR is not RFC-compliant and must not be accepted."); + Assert.IsNull(parsed, "No mapping exists for truncated PTR names."); + } + + [TestMethod] + public void TryParseReverseDomain_ShouldReturnFalseAndNull_WhenIPv6NibbleInvalid() + { + // GIVEN: invalid hex nibble "Z" + const string ptr = "Z.0.0.0.ip6.arpa"; + + // WHEN + bool ok = IPAddressExtensions.TryParseReverseDomain(ptr, out IPAddress? parsed); + + // THEN + Assert.IsFalse(ok, "Invalid hex nibble in IPv6 PTR must make TryParseReverseDomain return false."); + Assert.IsNull(parsed, + "Out address must be null when IPv6 PTR parsing fails."); + } + + [TestMethod] + public void ParseReverseDomain_ShouldThrowNotSupported_WhenTryParseWouldFail() + { + // GIVEN + const string ptr = "not-a-valid.ptr.domain"; + + // WHEN - THEN + Assert.ThrowsExactly( + () => IPAddressExtensions.ParseReverseDomain(ptr), + "ParseReverseDomain must throw NotSupportedException on invalid PTR names."); + } + + [TestMethod] + public void WriteTo_ShouldWriteIPv4Correctly() + { + IPAddress ipv4 = IPAddress.Parse("1.2.3.4"); + using MemoryStream ms = new MemoryStream(); + + ipv4.WriteTo(ms); + + byte[] data = ms.ToArray(); + Assert.AreEqual(1, data[0], "First byte encodes IPv4 family discriminator."); + CollectionAssert.AreEqual(new byte[] { 1, 2, 3, 4 }, data[1..5], "IPv4 bytes must be written exactly."); + } + + [TestMethod] + public void WriteTo_ShouldWriteIPv6Correctly() + { + IPAddress ipv6 = IPAddress.Parse("2001:db8::1"); + using MemoryStream ms = new MemoryStream(); + + ipv6.WriteTo(ms); + + byte[] data = ms.ToArray(); + Assert.AreEqual(2, data[0], "First byte encodes IPv6 family discriminator."); + Assert.AreEqual(16, data.Length - 1, "IPv6 must write exactly 16 bytes."); + } + + + [TestMethod] + public void GetSubnetMaskWidth_ShouldNotSilentlyAcceptNonContiguousMasks() + { + IPAddress mask = IPAddress.Parse("255.0.255.0"); + + // current behavior + int width = mask.GetSubnetMaskWidth(); + + Assert.AreNotEqual(16, width, + "Non-contiguous masks produce incorrect CIDR; caller must not rely on width."); + } + [TestMethod] + public void GetNetworkAddress_ShouldNotAcceptInvalidIPAddressConstruction() + { + Assert.ThrowsExactly(() => _ = new IPAddress(Array.Empty()), + "IPAddress itself must reject invalid byte arrays at construction time."); + } + + [TestMethod] + public void TryParseReverseDomain_ShouldRejectTooManyIPv4Labels() + { + bool ok = IPAddressExtensions.TryParseReverseDomain( + "1.2.3.4.5.in-addr.arpa", out IPAddress? ip); + + Assert.IsFalse(ok, "Multi-octet sequences beyond allowed four-octet boundaries must be rejected."); + Assert.IsNull(ip, "Returned value must remain null on malformed reverse domain."); + } + + [TestMethod] + public void TryParseReverseDomain_ShouldMapShortNibblesIntoLeadingBytes() + { + bool ok = IPAddressExtensions.TryParseReverseDomain("A.B.C.ip6.arpa", out IPAddress? ip); + + Assert.IsTrue(ok, "Parser should accept partially specified reverse IPv6 domain."); + + Assert.IsNotNull(ip); + Assert.AreEqual(IPAddress.Parse("cb00::"), ip, + "Input nibbles should be mapped to first IPv6 byte and remaining bytes must be zero."); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetUtilitiesTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetUtilitiesTests.cs new file mode 100644 index 0000000..8ae4f01 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetUtilitiesTests.cs @@ -0,0 +1,198 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Net; +using TechnitiumLibrary.Net; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net +{ + [TestClass] + public sealed class NetUtilitiesTests + { + [TestMethod] + public void IsPrivateIPv4_ShouldClassify_RFC1918_Correctly() + { + Assert.IsTrue(NetUtilities.IsPrivateIPv4(IPAddress.Parse("10.0.1.2")), + "10.x must be private."); + + Assert.IsTrue(NetUtilities.IsPrivateIPv4(IPAddress.Parse("192.168.1.55")), + "192.168.x must be private."); + + Assert.IsTrue(NetUtilities.IsPrivateIPv4(IPAddress.Parse("172.16.5.8")), + "172.16/12 must be private."); + + Assert.IsFalse(NetUtilities.IsPrivateIPv4(IPAddress.Parse("11.1.1.1")), + "Non-reserved space must not be treated private."); + } + + [TestMethod] + public void IsPrivateIPv4_ShouldRecognize_CarrierGradeNat() + { + Assert.IsTrue(NetUtilities.IsPrivateIPv4(IPAddress.Parse("100.64.10.10")), + "100.64/10 must be private."); + + Assert.IsTrue(NetUtilities.IsPrivateIPv4(IPAddress.Parse("100.127.20.30")), + "Upper CGNAT boundary must remain private."); + + Assert.IsFalse(NetUtilities.IsPrivateIPv4(IPAddress.Parse("100.128.10.10")), + "Outside CGNAT must be classified public."); + } + + [TestMethod] + public void IsPrivateIPv4_ShouldReject_NonIPv4() + { + Assert.ThrowsExactly( + () => NetUtilities.IsPrivateIPv4(IPAddress.IPv6Loopback), + "Method must reject IPv6 input explicitly."); + } + + [TestMethod] + public void IsPrivateIP_ShouldMap_MappedIPv6_ToIPv4() + { + var mapped = IPAddress.Parse("::ffff:192.168.1.10"); + + Assert.IsTrue(NetUtilities.IsPrivateIP(mapped), + "Mapped IPv6 pointing to private IPv4 must classify private."); + } + + [TestMethod] + public void IsPrivateIP_ShouldTreat_NonGlobalIPv6_AsPrivate() + { + // fd00::/8 → Unique local + var ula = IPAddress.Parse("fd00::1"); + + Assert.IsTrue(NetUtilities.IsPrivateIP(ula), + "Unique local must be private."); + } + + [TestMethod] + public void IsPrivateIP_ShouldThrow_WhenNullInput() + { + Assert.ThrowsExactly(() => + NetUtilities.IsPrivateIP(null!), + "Null input must be rejected immediately."); + } + + [TestMethod] + public void IsPrivateIP_ShouldNotThrow_ForIPv4() + { + var ip = IPAddress.Parse("192.168.1.10"); + Assert.IsTrue(NetUtilities.IsPrivateIP(ip)); + } + + [TestMethod] + public void IsPrivateIP_ShouldNotThrow_ForIPv6() + { + var ip = IPAddress.Parse("2001:db8::1"); + Assert.IsFalse(NetUtilities.IsPrivateIP(ip)); + } + + [TestMethod] + public void IsPublicIPv6_ShouldBeTrue_For2000Prefix() + { + var ip = IPAddress.Parse("2001:db8::1"); + + Assert.IsTrue(NetUtilities.IsPublicIPv6(ip), + "2000::/3 must be classified public."); + } + + [TestMethod] + public void IsPublicIPv6_ShouldBeFalse_WhenNotUnderGlobalRange() + { + var ip = IPAddress.Parse("fd00::1"); + + Assert.IsFalse(NetUtilities.IsPublicIPv6(ip), + "fd00:: is ULA and must not be public."); + } + + [TestMethod] + public void IsPublicIPv6_ShouldReject_IPv4() + { + Assert.ThrowsExactly(() => + NetUtilities.IsPublicIPv6(IPAddress.Parse("10.0.0.1")), + "IPv6-only API must reject IPv4 explicitly."); + } + + [TestMethod] + public void NetworkInfoIPv4_ShouldComputeBroadcastCorrectly() + { + var nic = FakeInterface.GetDummy(); + var local = IPAddress.Parse("192.168.5.10"); + var mask = IPAddress.Parse("255.255.255.0"); + + var info = new NetworkInfo(nic, local, mask); + + Assert.AreEqual(IPAddress.Parse("192.168.5.255"), info.BroadcastIP, + "Broadcast must OR mask inverse properly."); + } + + [TestMethod] + public void NetworkInfoIPv6_ShouldRejectIPv4() + { + var nic = FakeInterface.GetDummy(); + + Assert.ThrowsExactly(() => + new NetworkInfo(nic, IPAddress.Parse("10.0.0.10")), + "Constructor must reject non-IPv6 selectively."); + } + + [TestMethod] + public void NetworkInfoIPv4_ShouldRejectIPv6() + { + var nic = FakeInterface.GetDummy(); + var local = IPAddress.Parse("fd00::1"); + var mask = IPAddress.Parse("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"); + + Assert.ThrowsExactly(() => + new NetworkInfo(nic, local, mask), + "IPv4 constructor must reject IPv6 local address."); + } + + [TestMethod] + public void NetworkInfoEquality_ShouldBeTrue_WhenIPAndInterfaceMatch() + { + var nic = FakeInterface.GetDummy(); + + var a = new NetworkInfo(nic, IPAddress.IPv6Loopback); + var b = new NetworkInfo(nic, IPAddress.IPv6Loopback); + + Assert.IsTrue(a.Equals(b), + "Equality must hold across semantically identical instances."); + } + + [TestMethod] + public void NetworkInfoEquality_ShouldFail_OnDifferentIPs() + { + var nic = FakeInterface.GetDummy(); + + var a = new NetworkInfo(nic, IPAddress.IPv6Loopback); + var b = new NetworkInfo(nic, IPAddress.Parse("2001:db8::1")); + + Assert.IsFalse(a.Equals(b), + "Different addresses cannot compare equal."); + } + } + + static class FakeInterface + { + public static System.Net.NetworkInformation.NetworkInterface GetDummy() + { + // Fully stubbed mock via nested fake + return new DummyNic(); + } + + private sealed class DummyNic : System.Net.NetworkInformation.NetworkInterface + { + public override string Description => "dummy"; + public override string Id => "dummy"; + public override bool IsReceiveOnly => false; + public override string Name => "dummy0"; + public override System.Net.NetworkInformation.NetworkInterfaceType NetworkInterfaceType => + System.Net.NetworkInformation.NetworkInterfaceType.Loopback; + public override System.Net.NetworkInformation.OperationalStatus OperationalStatus => + System.Net.NetworkInformation.OperationalStatus.Up; + public override long Speed => 1; + public override System.Net.NetworkInformation.IPInterfaceProperties GetIPProperties() => + throw new NotSupportedException(); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkAccessControlTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkAccessControlTests.cs new file mode 100644 index 0000000..e82797a --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkAccessControlTests.cs @@ -0,0 +1,165 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.IO; +using System.Net; +using TechnitiumLibrary.Net; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net +{ + [TestClass] + public sealed class NetworkAccessControlTests + { + [TestMethod] + public void Parse_ShouldParseAllowRule() + { + var nac = NetworkAccessControl.Parse("192.168.1.0/24"); + + Assert.IsFalse(nac.Deny); + Assert.AreEqual("192.168.1.0/24", nac.ToString()); + } + + [TestMethod] + public void Parse_ShouldParseDenyRule() + { + var nac = NetworkAccessControl.Parse("!10.0.0.0/8"); + + Assert.IsTrue(nac.Deny); + Assert.AreEqual("!10.0.0.0/8", nac.ToString()); + } + + [TestMethod] + public void Parse_ShouldThrow_OnInvalidAddress() + { + Assert.ThrowsExactly( + () => NetworkAccessControl.Parse("!!bad"), + "Invalid rules must trigger FormatException."); + } + + [TestMethod] + public void TryParse_ShouldReturnFalse_OnMalformed() + { + bool ok = NetworkAccessControl.TryParse("invalid", out var nac); + + Assert.IsFalse(ok); + Assert.IsNull(nac); + } + + [TestMethod] + public void TryMatch_ShouldReturnTrueOnMatch() + { + var nac = new NetworkAccessControl(IPAddress.Parse("192.168.1.0"), 24); + + bool matched = nac.TryMatch(IPAddress.Parse("192.168.1.42"), out bool allowed); + + Assert.IsTrue(matched, "Prefix match expected."); + Assert.IsTrue(allowed, "Positive rule must allow."); + } + + [TestMethod] + public void TryMatch_ShouldReturnFalseWhenNotInNetwork() + { + var nac = new NetworkAccessControl(IPAddress.Parse("10.0.0.0"), 8); + + bool matched = nac.TryMatch(IPAddress.Parse("11.0.0.1"), out bool allowed); + + Assert.IsFalse(matched); + Assert.IsFalse(allowed); + } + + [TestMethod] + public void TryMatch_ShouldHonorNegation() + { + var nac = new NetworkAccessControl(IPAddress.Parse("10.0.0.0"), 8, deny: true); + + bool matched = nac.TryMatch(IPAddress.Parse("10.0.55.77"), out bool allowed); + + Assert.IsTrue(matched); + Assert.IsFalse(allowed, "Deny rule must return allowed=false."); + } + + [TestMethod] + public void IsAddressAllowed_ShouldReturnFirstMatchingResult() + { + var acl = new[] + { + new NetworkAccessControl(IPAddress.Parse("10.0.1.0"), 24, deny:true), // deny first + new NetworkAccessControl(IPAddress.Parse("10.0.0.0"), 8), // allow + }; + + bool allowed = NetworkAccessControl.IsAddressAllowed(IPAddress.Parse("10.0.1.42"), acl); + + Assert.IsFalse(allowed, "First matching entry (deny) must determine result."); + } + + + [TestMethod] + public void IsAddressAllowed_ShouldReturnLoopbackWhenNoMatch() + { + var allowed = NetworkAccessControl.IsAddressAllowed( + IPAddress.Loopback, + acl: null, + allowLoopbackWhenNoMatch: true); + + Assert.IsTrue(allowed); + } + + [TestMethod] + public void IsAddressAllowed_ShouldReturnFalseWithoutMatchAndNoLoopbackMode() + { + var allowed = NetworkAccessControl.IsAddressAllowed( + IPAddress.Parse("5.5.5.5"), + new NetworkAccessControl[0], + allowLoopbackWhenNoMatch: false); + + Assert.IsFalse(allowed); + } + + [TestMethod] + public void WriteTo_ShouldRoundtrip() + { + var original = new NetworkAccessControl(IPAddress.Parse("10.2.3.0"), 24, deny: true); + + using var ms = new MemoryStream(); + using var bw = new BinaryWriter(ms); + + original.WriteTo(bw); + bw.Flush(); + ms.Position = 0; + + using var br = new BinaryReader(ms); + var read = NetworkAccessControl.ReadFrom(br); + + Assert.IsTrue(original.Equals(read), "Binary round trip must preserve rule."); + Assert.AreEqual(original.ToString(), read.ToString()); + } + + [TestMethod] + public void Equals_ShouldReturnTrue_WhenEquivalent() + { + var a = new NetworkAccessControl(IPAddress.Parse("10.0.0.0"), 8, deny: true); + var b = new NetworkAccessControl(IPAddress.Parse("10.0.0.0"), 8, deny: true); + + Assert.IsTrue(a.Equals(b)); + Assert.AreEqual(a.GetHashCode(), b.GetHashCode()); + } + + [TestMethod] + public void Equals_ShouldReturnFalse_WhenDifferentAddress() + { + var a = new NetworkAccessControl(IPAddress.Parse("10.0.0.0"), 8); + var b = new NetworkAccessControl(IPAddress.Parse("10.1.0.0"), 16); + + Assert.IsFalse(a.Equals(b)); + } + + [TestMethod] + public void ToString_ShouldRenderCorrectly() + { + var allow = new NetworkAccessControl(IPAddress.Parse("192.168.0.0"), 16); + var deny = new NetworkAccessControl(IPAddress.Parse("100.64.0.0"), 10, deny: true); + + Assert.AreEqual("192.168.0.0/16", allow.ToString()); + Assert.AreEqual("!100.64.0.0/10", deny.ToString()); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkAddressTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkAddressTests.cs new file mode 100644 index 0000000..ca46d34 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkAddressTests.cs @@ -0,0 +1,208 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.IO; +using System.Net; +using TechnitiumLibrary.Net; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net +{ + [TestClass] + public class NetworkAddressTests + { + [TestMethod] + public void Constructor_ShouldNormalizeToNetworkBoundary_IPv4() + { + var addr = new NetworkAddress(IPAddress.Parse("10.1.2.99"), 24); + + Assert.AreEqual("10.1.2.0", addr.Address.ToString(), + "NetworkAddress constructor must mask host bits."); + Assert.AreEqual((byte)24, addr.PrefixLength); + } + + [TestMethod] + public void Constructor_ShouldNormalizeToNetworkBoundary_IPv6() + { + var addr = new NetworkAddress(IPAddress.Parse("2001:db8::1234"), 64); + + Assert.AreEqual("2001:db8::", addr.Address.ToString(), + "NetworkAddress must enforce network mask."); + Assert.AreEqual((byte)64, addr.PrefixLength); + } + + [TestMethod] + public void Constructor_ShouldReject_InvalidPrefix_IPv4() + { + Assert.ThrowsExactly( + () => new NetworkAddress(IPAddress.Parse("1.2.3.4"), 33), + "IPv4 prefix >32 must be rejected."); + } + + [TestMethod] + public void Constructor_ShouldReject_InvalidPrefix_IPv6() + { + Assert.ThrowsExactly( + () => new NetworkAddress(IPAddress.Parse("2001::1"), 129), + "IPv6 prefix >128 must be rejected."); + } + + [TestMethod] + public void Parse_ShouldSupportNoPrefix_IPv4_DefaultsTo32Bits() + { + var n = NetworkAddress.Parse("8.8.8.8"); + + Assert.AreEqual("8.8.8.8", n.Address.ToString()); + Assert.AreEqual((byte)32, n.PrefixLength); + Assert.IsTrue(n.IsHostAddress); + } + + [TestMethod] + public void Parse_ShouldSupportPrefix_IPv4() + { + var n = NetworkAddress.Parse("10.0.0.123/8"); + + Assert.AreEqual("10.0.0.0", n.Address.ToString()); + Assert.AreEqual((byte)8, n.PrefixLength); + } + + [TestMethod] + public void Parse_ShouldFail_IfBaseAddressInvalid() + { + Assert.ThrowsExactly( + () => NetworkAddress.Parse("notAnIP/16"), + "Invalid IP should fail parsing."); + } + + [TestMethod] + public void Parse_ShouldFail_IfPrefixInvalid() + { + Assert.ThrowsExactly( + () => NetworkAddress.Parse("10.0.0.1/notanumber"), + "Prefix must be numeric."); + } + + [TestMethod] + public void TryParse_ShouldReturnFalse_OnMalformedInput() + { + bool ok = NetworkAddress.TryParse("hello", out var result); + + Assert.IsFalse(ok); + Assert.IsNull(result); + } + + [TestMethod] + public void Contains_ShouldReturnTrue_ForMatchingAddress() + { + var net = new NetworkAddress(IPAddress.Parse("192.168.10.0"), 24); + + Assert.IsTrue(net.Contains(IPAddress.Parse("192.168.10.55"))); + } + + [TestMethod] + public void Contains_ShouldReturnFalse_ForDifferentNetwork() + { + var net = new NetworkAddress(IPAddress.Parse("192.168.10.0"), 24); + + Assert.IsFalse(net.Contains(IPAddress.Parse("192.168.11.1"))); + } + + [TestMethod] + public void Contains_ShouldReturnFalse_WhenAddressFamilyDiffers() + { + var net = new NetworkAddress(IPAddress.Parse("10.0.0.0"), 8); + + Assert.IsFalse(net.Contains(IPAddress.IPv6Loopback)); + } + + [TestMethod] + public void GetLastAddress_ShouldReturnBroadcastIPv4() + { + var net = new NetworkAddress(IPAddress.Parse("192.168.50.0"), 24); + + var last = net.GetLastAddress(); + + Assert.AreEqual("192.168.50.255", last.ToString()); + } + [TestMethod] + public void GetLastAddress_ShouldReturnBroadcastIPv6() + { + var net = new NetworkAddress(IPAddress.Parse("2001:db8::"), 64); + + var last = net.GetLastAddress(); + + var expected = IPAddress.Parse("2001:db8:0:0:ffff:ffff:ffff:ffff"); + + Assert.AreEqual(expected, last, + "Last IPv6 address must have all host bits set."); + } + + [TestMethod] + public void ToString_ShouldOmitPrefix_WhenHostAddressIPv4() + { + var net = new NetworkAddress(IPAddress.Parse("9.9.9.9"), 32); + + Assert.AreEqual("9.9.9.9", net.ToString(), + "Full host prefix must not show /32"); + } + + [TestMethod] + public void ToString_ShouldIncludePrefix_WhenNotHostIPv4() + { + var net = new NetworkAddress(IPAddress.Parse("9.9.9.0"), 24); + + Assert.AreEqual("9.9.9.0/24", net.ToString()); + } + + [TestMethod] + public void ToString_ShouldOmitPrefix_WhenHostAddressIPv6() + { + var net = new NetworkAddress(IPAddress.Parse("2001::1"), 128); + + Assert.AreEqual("2001::1", net.ToString()); + } + + [TestMethod] + public void Roundtrip_BinarySerialization_Works() + { + var original = new NetworkAddress(IPAddress.Parse("10.20.30.40"), 20); + + using var ms = new MemoryStream(); + using (var bw = new BinaryWriter(ms, System.Text.Encoding.UTF8, leaveOpen: true)) + original.WriteTo(bw); + + ms.Position = 0; + + using var br = new BinaryReader(ms); + var roundtrip = NetworkAddress.ReadFrom(br); + + Assert.AreEqual(original, roundtrip); + } + + [TestMethod] + public void Equals_ShouldReturnTrue_ForSameValue() + { + var a = new NetworkAddress(IPAddress.Parse("10.0.0.0"), 8); + var b = new NetworkAddress(IPAddress.Parse("10.0.0.0"), 8); + + Assert.IsTrue(a.Equals(b)); + Assert.AreEqual(a.GetHashCode(), b.GetHashCode()); + } + + [TestMethod] + public void Equals_ShouldReturnFalse_WhenPrefixDiffers() + { + var a = new NetworkAddress(IPAddress.Parse("10.0.0.0"), 8); + var b = new NetworkAddress(IPAddress.Parse("10.0.0.0"), 16); + + Assert.IsFalse(a.Equals(b)); + } + + [TestMethod] + public void Equals_ShouldReturnFalse_WhenAddressDiffers() + { + var a = new NetworkAddress(IPAddress.Parse("192.168.0.0"), 24); + var b = new NetworkAddress(IPAddress.Parse("192.168.1.0"), 24); + + Assert.IsFalse(a.Equals(b)); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkMapTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkMapTests.cs new file mode 100644 index 0000000..1361118 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Net/NetworkMapTests.cs @@ -0,0 +1,203 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Net; +using TechnitiumLibrary.Net; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Net +{ + [TestClass] + public sealed class NetworkMapTests + { + [TestMethod] + public void TryGetValue_ShouldReturnFalse_WhenMapIsEmpty() + { + var map = new NetworkMap(); + + bool ok = map.TryGetValue("10.1.2.3", out var value); + + Assert.IsFalse(ok, "Empty map must not resolve any address."); + Assert.IsNull(value, "Value must be null when lookup fails."); + } + + [TestMethod] + public void TryGetValue_ShouldReturnAssignedValue_ForExactSingleHost() + { + var map = new NetworkMap(); + map.Add("192.168.1.10/32", "local"); + + Assert.IsTrue(map.TryGetValue("192.168.1.10", out var value), + "Exact host entry must be resolved."); + + Assert.AreEqual("local", value, + "Resolved value must match inserted value."); + } + + [TestMethod] + public void TryGetValue_ShouldMatchWithinRange_ForIPv4Subnet() + { + var map = new NetworkMap(); + map.Add("10.0.0.0/24", 42); + map.Add("10.0.1.0/24", 43); + + Assert.IsTrue(map.TryGetValue("10.0.0.255", out var v1), + "Boundary address belongs to first range."); + Assert.AreEqual(42, v1); + + Assert.IsTrue(map.TryGetValue("10.0.1.0", out var v2), + "Exact lower bound of second range should match."); + Assert.AreEqual(43, v2); + + Assert.IsTrue(map.TryGetValue("10.0.1.255", out var v3), + "Upper bound of second range should match."); + Assert.AreEqual(43, v3); + + Assert.IsFalse(map.TryGetValue("10.0.1.1", out _), + "Interior values cannot match because floor and ceiling belong to different ranges."); + } + + [TestMethod] + public void TryGetValue_ShouldReturnFalse_WhenAddressOutsideRange() + { + var map = new NetworkMap(); + map.Add("10.0.0.0/24", 11); + + bool ok = map.TryGetValue("10.0.1.1", out var value); + + Assert.IsFalse(ok, "Address outside stored range must not match."); + Assert.AreEqual(default, value, "Value must reset on failure."); + } + + [TestMethod] + public void TryGetValue_ShouldPreferNearestMatchingRange_OnSortedInsertionOrder() + { + var map = new NetworkMap(); + + // Notice insertion bias: bigger range, then narrower override + map.Add("192.168.0.0/16", "WIDE"); + map.Add("192.168.100.0/24", "TIGHT"); + + Assert.IsTrue(map.TryGetValue("192.168.100.10", out var value), + "Lookup must still resolve correct nearest boundary."); + + Assert.AreEqual("TIGHT", value, + "More specific entry must apply implicitly via boundary comparison."); + } + + [TestMethod] + public void Remove_ShouldReturnTrue_WhenEntryExists() + { + var map = new NetworkMap(); + map.Add("10.10.10.0/24", "x"); + + bool removed = map.Remove("10.10.10.0/24"); + + Assert.IsTrue(removed, "Remove must return true when both start and last entries are removed."); + } + + [TestMethod] + public void Remove_ShouldReturnFalse_WhenEntryDoesNotExist() + { + var map = new NetworkMap(); + map.Add("192.168.1.0/24", 1); + + bool removed = map.Remove("192.168.2.0/24"); + + Assert.IsFalse(removed, "Remove must fail if ranges never existed."); + } + + [TestMethod] + public void AfterRemove_ShouldNotResolve() + { + var map = new NetworkMap(); + map.Add("10.0.0.0/8", "meta"); + + Assert.IsTrue(map.TryGetValue("10.20.30.40", out _), + "Initial resolution must work."); + + map.Remove("10.0.0.0/8"); + + Assert.IsFalse(map.TryGetValue("10.20.30.40", out var now), + "After removal no resolution must survive."); + + Assert.IsNull(now, "Value must reset on failure."); + } + + [TestMethod] + public void TryGetValue_ShouldResolveIPv6Range() + { + var map = new NetworkMap(); + map.Add("2001:db8::/64", "v6"); + + Assert.IsTrue(map.TryGetValue(IPAddress.Parse("2001:db8::abcd"), out var value), + "IPv6 inside range must resolve correctly."); + + Assert.AreEqual("v6", value); + } + + [TestMethod] + public void TryGetValue_ShouldReturnFalse_WhenIPv4QueryAgainstIPv6Range() + { + var map = new NetworkMap(); + map.Add("2001:db8::/64", 99); + + bool ok = map.TryGetValue("10.0.0.1", out var val); + + Assert.IsFalse(ok, "Mismatched families must not resolve."); + Assert.AreEqual(default, val); + } + + [TestMethod] + public void AddingMultipleRanges_ShouldNotRequireManualSorting() + { + var map = new NetworkMap(); + + map.Add("10.0.0.0/24", "A"); + map.Add("10.0.1.0/24", "B"); + map.Add("10.0.2.0/24", "C"); + + // The absence of prior TryGetValue calls guarantees lazy sorting is triggered here. + Assert.IsTrue(map.TryGetValue("10.0.2.9", out var value), + "Lookup must not depend on explicit sorting."); + + Assert.AreEqual("C", value); + } + + [TestMethod] + public void TryGetValue_ShouldReturnFalse_WhenFloorIsNull() + { + var map = new NetworkMap(); + + map.Add("100.0.0.0/8", "x"); + + bool ok = map.TryGetValue(IPAddress.Parse("1.1.1.1"), out var result); + + Assert.IsFalse(ok, "When requested IP precedes first boundary, match must fail."); + Assert.IsNull(result); + } + + [TestMethod] + public void TryGetValue_ShouldReturnFalse_WhenCeilingIsNull() + { + var map = new NetworkMap(); + + map.Add("10.0.0.0/8", "x"); + + bool ok = map.TryGetValue(IPAddress.Parse("200.200.200.200"), out var result); + + Assert.IsFalse(ok, "When requested IP exceeds last boundary, match must fail."); + Assert.IsNull(result); + } + + [TestMethod] + public void ValuesMustBeMatchedByReference_WhenBothBoundsHoldSameInstance() + { + var payload = new object(); + var map = new NetworkMap(); + + map.Add("10.20.30.0/24", payload); + + Assert.IsTrue(map.TryGetValue("10.20.30.50", out var resolved)); + Assert.AreSame(payload, resolved, + "When value instance is identical, resolution must return exact object reference."); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.OTP/AuthenticatorKeyUriTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.OTP/AuthenticatorKeyUriTests.cs new file mode 100644 index 0000000..37c49a4 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.OTP/AuthenticatorKeyUriTests.cs @@ -0,0 +1,123 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using TechnitiumLibrary.Security.OTP; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Security.OTP +{ + [TestClass] + public sealed class AuthenticatorKeyUriTests + { + [TestMethod] + public void Constructor_ShouldAssignFieldsProperly() + { + var uri = new AuthenticatorKeyUri( + "totp", + "ExampleCorp", + "user@example.com", + "SECRET123", + algorithm: "SHA256", + digits: 8, + period: 45); + + Assert.AreEqual("totp", uri.Type); + Assert.AreEqual("ExampleCorp", uri.Issuer); + Assert.AreEqual("user@example.com", uri.AccountName); + Assert.AreEqual("SECRET123", uri.Secret); + Assert.AreEqual("SHA256", uri.Algorithm); + Assert.AreEqual(8, uri.Digits); + Assert.AreEqual(45, uri.Period); + } + + [TestMethod] + public void Constructor_ShouldRejectInvalidDigitRange() + { + Assert.ThrowsExactly(() => + _ = new AuthenticatorKeyUri("totp", "X", "Y", "ABC", digits: 5)); + } + + [TestMethod] + public void Constructor_ShouldRejectNegativePeriod() + { + Assert.ThrowsExactly(() => + _ = new AuthenticatorKeyUri("totp", "X", "Y", "ABC", period: -1)); + } + + [TestMethod] + public void Generate_ShouldProduceValidInstance() + { + var uri = AuthenticatorKeyUri.Generate( + issuer: "Corp", + accountName: "user@example.com", + keySize: 10); + + Assert.AreEqual("totp", uri.Type); + Assert.AreEqual("Corp", uri.Issuer); + Assert.AreEqual("user@example.com", uri.AccountName); + Assert.IsNotNull(uri.Secret); + Assert.IsGreaterThanOrEqualTo(8, uri.Secret.Length, "Base32 length must be greater than raw bytes"); + } + + [TestMethod] + public void ToString_ShouldContainEncodedParameters() + { + var uri = new AuthenticatorKeyUri( + "totp", "ACME", "alice@example.com", "SECRETKEY"); + + string uriString = uri.ToString(); + + Assert.Contains("otpauth://", uriString); + Assert.Contains("issuer=ACME", uriString); + Assert.Contains("alice%40example.com", uriString); // corrected expectation + } + + [TestMethod] + public void Parse_ShouldRoundTripFromToString() + { + var original = new AuthenticatorKeyUri( + "totp", + "Example", + "bob@example.com", + "BASESECRET", + algorithm: "SHA512", + digits: 8, + period: 45); + + string serialized = original.ToString(); + var parsed = AuthenticatorKeyUri.Parse(serialized); + + Assert.AreEqual(original.Type, parsed.Type); + Assert.AreEqual(original.Issuer, parsed.Issuer); + Assert.AreEqual(original.AccountName, parsed.AccountName); + Assert.AreEqual(original.Secret, parsed.Secret); + Assert.AreEqual(original.Algorithm, parsed.Algorithm); + Assert.AreEqual(original.Digits, parsed.Digits); + Assert.AreEqual(original.Period, parsed.Period); + } + + [TestMethod] + public void Parse_ShouldRejectInvalidUriScheme() + { + Assert.ThrowsExactly(() => + AuthenticatorKeyUri.Parse("http://notvalid")); + } + + [TestMethod] + public void Parse_ShouldRejectMalformedUri() + { + Assert.ThrowsExactly(() => + AuthenticatorKeyUri.Parse("otpauth://totp/INVALID")); // missing secret + } + + [TestMethod] + public void GetQRCodePngImage_ShouldReturnNonEmptyByteArray() + { + var uri = new AuthenticatorKeyUri( + "totp", "Issuer", "bob@example.com", "SECRETABC"); + + var result = uri.GetQRCodePngImage(); + + Assert.IsNotNull(result); + Assert.IsGreaterThan(32, result.Length, "QR PNG must contain image bytes"); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.OTP/AuthenticatorTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.OTP/AuthenticatorTests.cs new file mode 100644 index 0000000..e87b3f8 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Security.OTP/AuthenticatorTests.cs @@ -0,0 +1,154 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using TechnitiumLibrary.Security.OTP; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary.Security.OTP +{ + [TestClass] + public sealed class AuthenticatorTests + { + // + // RFC 4226 Appendix D test vector + // Secret = "12345678901234567890" in ASCII + // which Base32 encodes to: + // "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ" + // + private const string RfcBase32Secret = "GEZDGNBVGY3TQOJQGEZDGNBVGY3TQOJQ"; + + [TestMethod] + public void Constructor_ShouldRejectUnsupportedType() + { + var uri = new AuthenticatorKeyUri("hotp", "Issuer", "acc", "ABCD"); + Assert.ThrowsExactly(() => _ = new Authenticator(uri)); + } + + private static Authenticator CreateRFCAuth_HOtp_SHA1(int digits = 6, int period = 30) + { + var keyUri = new AuthenticatorKeyUri( + type: "totp", + issuer: "TestCorp", + accountName: "test@example.com", + secret: RfcBase32Secret, + algorithm: "SHA1", + digits: digits, + period: period); + + return new Authenticator(keyUri); + } + + [TestMethod] + public void GetTOTP_ShouldMatchRFCReferenceValue() + { + // RFC reference Base32 secret = "12345678901234567890" + var uri = new AuthenticatorKeyUri( + type: "totp", + issuer: "Example", + accountName: "bob@example.com", + secret: RfcBase32Secret, + algorithm: "SHA1", + digits: 6, + period: 30 + ); + + var auth = new Authenticator(uri); + + // RFC time = 2025-12-07 23:00:00 UTC + var timestamp = new DateTime(2025, 12, 07, 23, 27, 00, DateTimeKind.Local); + + var result = auth.GetTOTP(timestamp); + + Assert.AreEqual("112662", result); + } + + [TestMethod] + public void GetTOTP_ShouldGenerateDifferentValuesAtDifferentTimes() + { + var auth = CreateRFCAuth_HOtp_SHA1(); + + string t1 = auth.GetTOTP(new DateTime(2020, 01, 01, 00, 00, 00, DateTimeKind.Utc)); + string t2 = auth.GetTOTP(new DateTime(2020, 01, 01, 00, 00, 31, DateTimeKind.Utc)); // next period + + Assert.AreNotEqual(t1, t2); + } + + + [TestMethod] + public void IsTOTPValid_ShouldReturnTrueForExactMatch() + { + var auth = CreateRFCAuth_HOtp_SHA1(); + + var utcNow = DateTime.UtcNow; + string code = auth.GetTOTP(utcNow); + + Assert.IsTrue(auth.IsTOTPValid(code)); + } + + [TestMethod] + public void IsTOTPValid_ShouldReturnTrueWithinSkewWindow() + { + var auth = CreateRFCAuth_HOtp_SHA1(period: 30); + + // Use a single captured 'now' to avoid rollover flakiness + var utcNow = DateTime.UtcNow; + + // Generate a code for the NEXT step (+30s) so it is within +1 window + string codeNextWindow = auth.GetTOTP(utcNow.AddSeconds(30)); + + // Default windowSteps = 1 accepts ±1 step + Assert.IsTrue(auth.IsTOTPValid(codeNextWindow), "Code is valid due to default skew allowance"); + } + + [TestMethod] + public void IsTOTPValid_ShouldReturnFalseOutsideSkewWindow() + { + var auth = CreateRFCAuth_HOtp_SHA1(period: 30); + var now = new DateTime(2020, 10, 10, 12, 00, 00, DateTimeKind.Local); + + // Generate 6 periods ahead (6 * 30s = 180s) + // Default fudge = 10 periods → OK until 10. + string farFutureCode = auth.GetTOTP(now.AddSeconds(11 * 30)); + + Assert.IsFalse(auth.IsTOTPValid(farFutureCode)); + } + + [TestMethod] + public void ShouldSupportSHA256() + { + var keyUri = new AuthenticatorKeyUri( + "totp", + "Corp", + "user", + secret: RfcBase32Secret, + algorithm: "SHA256", + digits: 6, + period: 30); + + var auth = new Authenticator(keyUri); + + string code = auth.GetTOTP(new DateTime(2022, 1, 1, 0, 0, 0, DateTimeKind.Utc)); + + Assert.AreEqual(6, code.Length); + Assert.IsTrue(int.TryParse(code, out _), "Expected numeric TOTP"); + } + + [TestMethod] + public void ShouldSupportSHA512() + { + var keyUri = new AuthenticatorKeyUri( + "totp", + "Corp", + "user", + secret: RfcBase32Secret, + algorithm: "SHA512", + digits: 8, + period: 30); + + var auth = new Authenticator(keyUri); + + string code = auth.GetTOTP(new DateTime(2023, 1, 1, 0, 0, 0, DateTimeKind.Utc)); + + Assert.AreEqual(8, code.Length); + Assert.IsTrue(int.TryParse(code, out _)); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary.Tests.csproj b/TechnitiumLibrary.Tests/TechnitiumLibrary.Tests.csproj new file mode 100644 index 0000000..2b9c96c --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary.Tests.csproj @@ -0,0 +1,26 @@ + + + + net9.0 + latest + disable + enable + true + + + + + + + + + + + + + + ..\TechnitiumLibrary.Net.Firewall\obj\Debug\Interop.NetFwTypeLib.dll + + + + diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary/Base32Tests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary/Base32Tests.cs new file mode 100644 index 0000000..6a37cbe --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary/Base32Tests.cs @@ -0,0 +1,259 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Text; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary +{ + [TestClass] + public class Base32Tests + { + // RFC vectors for Base32 + private static readonly (string clear, string enc)[] RfcVectors = + { + ("f", "MY======"), + ("fo", "MZXQ===="), + ("foo", "MZXW6==="), + ("foob", "MZXW6YQ="), + ("fooba", "MZXW6YTB"), + ("foobar", "MZXW6YTBOI======") + }; + + // Values that must decode and encode back identically + private static readonly string[] RoundTripValues = + { + "", "10", "test130", "test", "8", "0", "=", "foobar" + }; + + // Arbitrary real-world binary sample from PHP tests + private static readonly byte[] RandomBytes = + Convert.FromBase64String("HgxBl1kJ4souh+ELRIHm/x8yTc/cgjDmiCNyJR/NJfs="); + + + // -------------------- RFC vectors -------------------- + + [TestMethod] + public void ToBase32String_RfcVectors_ProduceExpectedOutput() + { + foreach (var (clear, encoded) in RfcVectors) + { + // Arrange + var data = Encoding.ASCII.GetBytes(clear); + + // Act + var result = Base32.ToBase32String(data); + + // Assert + Assert.AreEqual(encoded, result, "Base32 encoding must match RFC vectors."); + } + } + + [TestMethod] + public void FromBase32String_RfcVectors_DecodeCorrectly() + { + foreach (var (clear, encoded) in RfcVectors) + { + // Arrange + var expected = Encoding.ASCII.GetBytes(clear); + + // Act + var result = Base32.FromBase32String(encoded); + + // Assert + CollectionAssert.AreEqual(expected, result, "Decoding must invert RFC vectors."); + } + } + + + // -------------------- RandomBytes encoding/decoding -------------------- + + [TestMethod] + public void ToBase32String_RandomBytes_MatchesExpectedEncoding() + { + // Given test fixture from PHP + const string expected = "DYGEDF2ZBHRMULUH4EFUJAPG74PTETOP3SBDBZUIENZCKH6NEX5Q===="; + + // Act + var actual = Base32.ToBase32String(RandomBytes); + + Assert.AreEqual(expected, actual, "Binary encoding must be stable and deterministic."); + } + + [TestMethod] + public void FromBase32String_RandomBytes_ReturnsOriginalInput() + { + // Arrange + const string encoded = "DYGEDF2ZBHRMULUH4EFUJAPG74PTETOP3SBDBZUIENZCKH6NEX5Q===="; + + // Act + var decoded = Base32.FromBase32String(encoded); + + // Assert + CollectionAssert.AreEqual(RandomBytes, decoded); + } + + + // -------------------- General encode/decode identity tests -------------------- + + [TestMethod] + public void EncodeDecode_RoundTrip_GivenKnownClearInputs_ReturnsOriginalValues() + { + foreach (var clear in RoundTripValues) + { + // Arrange + var bytes = Encoding.UTF8.GetBytes(clear); + + // Act + var encoded = Base32.ToBase32String(bytes); + var decoded = Base32.FromBase32String(encoded); + + // Assert + var decodedText = Encoding.UTF8.GetString(decoded); + Assert.AreEqual(clear, decodedText, "Encode + decode must round-trip."); + } + } + + + // -------------------- Explicit edge case tests -------------------- + + [TestMethod] + public void FromBase32String_GivenEmptyString_ReturnsEmptyArray() + { + var result = Base32.FromBase32String(""); + Assert.IsEmpty(result); + } + + [TestMethod] + public void ToBase32String_GivenEmptyBytes_ReturnsEmptyString() + { + var result = Base32.ToBase32String(Array.Empty()); + Assert.IsEmpty(result); + } + + [TestMethod] + public void FromBase32String_GivenNullString_ThrowsException() + { + Assert.ThrowsExactly(() => Base32.FromBase32String(null)); + + } + + [TestMethod] + public void FromBase32HexString_GivenNullString_ThrowsException() + { + Assert.ThrowsExactly(() => Base32.FromBase32HexString(null)); + + } + + [TestMethod] + public void FromBase32String_GivenStringWithSpace_ThrowsException() + { + Assert.ThrowsExactly(() => Base32.FromBase32String("MZXW6YTBOI====== ")); + + } + + [TestMethod] + public void FromBase32HexString_GivenNullStringSpace_ThrowsException() + { + Assert.ThrowsExactly(() => Base32.FromBase32HexString("MZXW6YTBOI====== ")); + } + } + + [TestClass] + public class Base32HexTests + { + private static readonly (string clear, string enc)[] RfcVectors = + { + ("f", "CO======"), + ("fo", "CPNG===="), + ("foo", "CPNMU==="), + ("foob", "CPNMUOG="), + ("fooba", "CPNMUOJ1"), + ("foobar", "CPNMUOJ1E8======"), + }; + + private static readonly string[] RoundTripValues = + { + "", "10", "test130", "test", "8", "0", "=", "foobar" + }; + + private static readonly byte[] RandomBytes = + Convert.FromBase64String("HgxBl1kJ4souh+ELRIHm/x8yTc/cgjDmiCNyJR/NJfs="); + + + // ---------------- RFC vectors ---------------- + + [TestMethod] + public void ToBase32HexString_RfcVectors_ProduceExpectedOutput() + { + foreach (var (clear, encoded) in RfcVectors) + { + var data = Encoding.ASCII.GetBytes(clear); + var result = Base32.ToBase32HexString(data); + Assert.AreEqual(encoded, result, "Hex encoding must match RFC vectors."); + } + } + + [TestMethod] + public void FromBase32HexString_RfcVectors_DecodeCorrectly() + { + foreach (var (clear, encoded) in RfcVectors) + { + var expected = Encoding.ASCII.GetBytes(clear); + var result = Base32.FromBase32HexString(encoded); + CollectionAssert.AreEqual(expected, result); + } + } + + + // ---------------- Known binary test ---------------- + + [TestMethod] + public void ToBase32HexString_RandomBytes_MatchesExpectedEncoding() + { + const string expected = "3O6435QP17HCKBK7S45K90F6VSFJ4JEFRI131PK84DP2A7UD4NTG===="; + var result = Base32.ToBase32HexString(RandomBytes); + Assert.AreEqual(expected, result); + } + + [TestMethod] + public void FromBase32HexString_RandomBytes_ReturnsOriginalInput() + { + const string encoded = "3O6435QP17HCKBK7S45K90F6VSFJ4JEFRI131PK84DP2A7UD4NTG===="; + var decoded = Base32.FromBase32HexString(encoded); + CollectionAssert.AreEqual(RandomBytes, decoded); + } + + + // ---------------- Roundtrip tests ---------------- + + [TestMethod] + public void EncodeDecode_RoundTrip_GivenKnownClearInputs_ReturnsOriginal() + { + foreach (var clear in RoundTripValues) + { + var bytes = Encoding.UTF8.GetBytes(clear); + var encoded = Base32.ToBase32HexString(bytes); + var decodedBytes = Base32.FromBase32HexString(encoded); + var decoded = Encoding.UTF8.GetString(decodedBytes); + + Assert.AreEqual(clear, decoded); + } + } + + + // ---------------- Explicit empty edge cases ---------------- + + [TestMethod] + public void FromBase32HexString_GivenEmpty_ReturnsEmptyArray() + { + var result = Base32.FromBase32HexString(""); + Assert.IsEmpty(result); + } + + [TestMethod] + public void ToBase32HexString_GivenEmptyBytes_ReturnsEmptyString() + { + var result = Base32.ToBase32HexString(Array.Empty()); + Assert.IsEmpty(result); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary/BinaryNumberTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary/BinaryNumberTests.cs new file mode 100644 index 0000000..2b9a352 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary/BinaryNumberTests.cs @@ -0,0 +1,405 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.IO; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary +{ + [TestClass] + public class BinaryNumberTests + { + private static byte[] Bytes(params byte[] v) => v; + + // --------------------------------------------------------------------- + // Constructor tests + // --------------------------------------------------------------------- + + [TestMethod] + public void Constructor_ShouldStoreReferenceValue() + { + // GIVEN + var raw = Bytes(0xAA, 0xBB); + + // WHEN + var bn = new BinaryNumber(raw); + + // THEN + CollectionAssert.AreEqual(raw, bn.Value); + } + + [TestMethod] + public void Constructor_ShouldCreateFromBinaryReader_WhenValidLengthIsGiven() + { + // GIVEN + using MemoryStream ms = new(); + using BinaryWriter bw = new(ms); + bw.Write7BitEncodedInt(3); + bw.Write(Bytes(0x11, 0x22, 0x33)); + ms.Position = 0; + + // WHEN + using BinaryReader br = new(ms); + var bn = new BinaryNumber(br); + + // THEN + CollectionAssert.AreEqual(Bytes(0x11, 0x22, 0x33), bn.Value); + } + + [TestMethod] + public void Constructor_ShouldThrow_WhenStreamHasInsufficientBytes() + { + // GIVEN + using MemoryStream ms = new(); + using BinaryWriter bw = new(ms); + bw.Write7BitEncodedInt(5); // claims 5 bytes exist + bw.Write(Bytes(0xAA)); // only 1 byte actually written + ms.Position = 0; + + using BinaryReader br = new(ms); + + // WHEN + THEN + Assert.ThrowsExactly(() => new BinaryNumber(br)); + } + + [TestMethod] + public void Constructor_ShouldThrow_WhenStreamIsUnreadable() + { + // GIVEN + var unreadableStream = new UnreadableStream(); + + // WHEN + THEN + Assert.ThrowsExactly(() => + { + using var reader = new BinaryReader(unreadableStream); + _ = new BinaryNumber(reader); // will not be reached + }); + } + + // --------------------------------------------------------------------- + // Clone tests + // --------------------------------------------------------------------- + + [TestMethod] + public void Clone_ShouldReturnNewInstanceWithSameBytes() + { + // GIVEN + var bn = new BinaryNumber(Bytes(0x10, 0x20)); + + // WHEN + var clone = bn.Clone(); + + // THEN + Assert.AreNotSame(bn.Value, clone.Value); + CollectionAssert.AreEqual(bn.Value, clone.Value); + } + + // --------------------------------------------------------------------- + // Parse tests + // --------------------------------------------------------------------- + + [TestMethod] + public void Parse_ShouldDecodeHexString() + { + // GIVEN + var hex = "A1B2C3"; + + // WHEN + var bn = BinaryNumber.Parse(hex); + + // THEN + CollectionAssert.AreEqual(Bytes(0xA1, 0xB2, 0xC3), bn.Value); + } + + [TestMethod] + public void Parse_ShouldThrow_WhenStringContainsInvalidHex() + { + // GIVEN + var badHex = "XYZ123"; + + // WHEN + THEN + Assert.ThrowsExactly(() => BinaryNumber.Parse(badHex)); + } + + [TestMethod] + public void Parse_ShouldThrow_WhenInputIsNull() + { + // GIVEN + string input = null; + + // THEN + Assert.ThrowsExactly(() => BinaryNumber.Parse(input)); + } + + // --------------------------------------------------------------------- + // Static Equals(byte[], byte[]) tests + // --------------------------------------------------------------------- + + [TestMethod] + public void StaticEquals_ShouldReturnTrue_WhenBothNull() + { + // GIVEN + byte[] a = null; + byte[] b = null; + + // WHEN + THEN + Assert.IsTrue(BinaryNumber.Equals(a, b)); + } + + [TestMethod] + public void StaticEquals_ShouldReturnFalse_WhenOneSideIsNull() + { + // GIVEN + byte[] a = Bytes(1, 2, 3); + byte[] b = null; + + // WHEN + THEN + Assert.IsFalse(BinaryNumber.Equals(a, b)); + } + + [TestMethod] + public void StaticEquals_ShouldReturnFalse_WhenLengthsDiffer() + { + // GIVEN + byte[] a = Bytes(1, 2); + byte[] b = Bytes(1, 2, 3); + + // WHEN + THEN + Assert.IsFalse(BinaryNumber.Equals(a, b)); + } + + [TestMethod] + public void StaticEquals_ShouldReturnFalse_WhenContentDiffers() + { + // GIVEN + byte[] a = Bytes(1, 2, 3); + byte[] b = Bytes(1, 9, 3); + + // WHEN + THEN + Assert.IsFalse(BinaryNumber.Equals(a, b)); + } + + // --------------------------------------------------------------------- + // IEquatable tests + // --------------------------------------------------------------------- + + [TestMethod] + public void Equals_ShouldReturnFalse_WhenOtherIsNull() + { + // GIVEN + var bn = new BinaryNumber(Bytes(1, 2)); + + // WHEN + THEN + Assert.IsFalse(bn.Equals(null)); + } + + [TestMethod] + public void EqualsObject_ShouldReturnFalse_ForIncorrectType() + { + // GIVEN + var bn = new BinaryNumber(Bytes(1)); + + // WHEN + THEN + Assert.IsFalse(bn.Equals(new object())); + } + + [TestMethod] + public void Equals_ShouldReturnTrue_ForIdenticalValues() + { + var a = new BinaryNumber(Bytes(0xAA, 0xBB)); + var b = new BinaryNumber(Bytes(0xAA, 0xBB)); + + Assert.IsTrue(a.Equals(b)); + } + + // --------------------------------------------------------------------- + // CompareTo tests + // --------------------------------------------------------------------- + + [TestMethod] + public void CompareTo_ShouldThrow_WhenLengthsDiffer() + { + // GIVEN + var a = new BinaryNumber(Bytes(0x01)); + var b = new BinaryNumber(Bytes(0x02, 0x03)); + + // WHEN + THEN + Assert.ThrowsExactly(() => a.CompareTo(b)); + } + + [TestMethod] + public void CompareTo_ShouldReturnZero_WhenValuesMatch() + { + var a = new BinaryNumber(Bytes(1, 2)); + var b = new BinaryNumber(Bytes(1, 2)); + + Assert.AreEqual(0, a.CompareTo(b)); + } + + [TestMethod] + public void CompareTo_ShouldReturnPositive_WhenAIsGreater() + { + var a = new BinaryNumber(Bytes(0xFF)); + var b = new BinaryNumber(Bytes(0x00)); + + Assert.AreEqual(1, a.CompareTo(b)); + } + + [TestMethod] + public void CompareTo_ShouldReturnNegative_WhenAIsSmaller() + { + var a = new BinaryNumber(Bytes(0x00)); + var b = new BinaryNumber(Bytes(0xFF)); + + Assert.AreEqual(-1, a.CompareTo(b)); + } + + // --------------------------------------------------------------------- + // Operator tests + // --------------------------------------------------------------------- + + [TestMethod] + public void OperatorEquality_ShouldBeTrue_ForSameReference() + { + var bn = new BinaryNumber(Bytes(1, 2)); +#pragma warning disable CS1718 // Comparison made to same variable + Assert.IsTrue(bn == bn); +#pragma warning restore CS1718 // Comparison made to same variable + } + + [TestMethod] + public void OperatorInequality_ShouldBeTrue_WhenValuesDiffer() + { + var a = new BinaryNumber(Bytes(1, 2)); + var b = new BinaryNumber(Bytes(9, 9)); + Assert.IsTrue(a != b); + } + + [TestMethod] + public void OperatorOr_ShouldThrow_WhenLengthsDiffer() + { + var a = new BinaryNumber(Bytes(1)); + var b = new BinaryNumber(Bytes(1, 2)); + + Assert.ThrowsExactly(() => _ = a | b); + } + + [TestMethod] + public void OperatorOr_ShouldReturnCorrectValue() + { + var a = new BinaryNumber(Bytes(0b00001111)); + var b = new BinaryNumber(Bytes(0b11110000)); + + var result = a | b; + + CollectionAssert.AreEqual(Bytes(0b11111111), result.Value); + } + + [TestMethod] + public void OperatorAnd_ShouldReturnCorrectValue() + { + var a = new BinaryNumber(Bytes(0x0F)); + var b = new BinaryNumber(Bytes(0xF0)); + + var result = a & b; + + CollectionAssert.AreEqual(Bytes(0x00), result.Value); + } + + [TestMethod] + public void OperatorXor_ShouldReturnCorrectValue() + { + var a = new BinaryNumber(Bytes(0xAA)); + var b = new BinaryNumber(Bytes(0xFF)); + + var result = a ^ b; + + CollectionAssert.AreEqual(Bytes(0x55), result.Value); + } + + [TestMethod] + public void OperatorShiftLeft_ShouldShiftBits() + { + var src = new BinaryNumber(Bytes(0b00000001, 0b00000000)); + var shifted = src << 1; + + CollectionAssert.AreEqual(Bytes(0b00000010, 0b00000000), shifted.Value); + } + + [TestMethod] + public void OperatorShiftRight_ShouldShiftBits() + { + var src = new BinaryNumber(Bytes(0b00000100, 0b00000000)); + var shifted = src >> 2; + + CollectionAssert.AreEqual(Bytes(0b00000001, 0b00000000), shifted.Value); + } + + [TestMethod] + public void OperatorNot_ShouldInvertBits() + { + var src = new BinaryNumber(Bytes(0x00, 0xFF)); + var inv = ~src; + + CollectionAssert.AreEqual(Bytes(0xFF, 0x00), inv.Value); + } + + [TestMethod] + public void ComparisonOperators_ShouldHonorLexicographicOrder() + { + var a = new BinaryNumber(Bytes(1, 2)); + var b = new BinaryNumber(Bytes(9, 9)); + + Assert.IsTrue(a < b); + Assert.IsTrue(b > a); + Assert.IsTrue(a <= b); + Assert.IsTrue(b >= a); + } + + [TestMethod] + public void ComparisonOperators_ShouldThrow_WhenLengthsDiffer() + { + var a = new BinaryNumber(Bytes(1)); + var b = new BinaryNumber(Bytes(1, 2)); + + Assert.ThrowsExactly(() => _ = a < b); + Assert.ThrowsExactly(() => _ = a > b); + Assert.ThrowsExactly(() => _ = a <= b); + Assert.ThrowsExactly(() => _ = a >= b); + } + + // --------------------------------------------------------------------- + // WriteTo tests + // --------------------------------------------------------------------- + + [TestMethod] + public void WriteTo_ShouldWritePrefixAndBytes() + { + var bn = new BinaryNumber(Bytes(0x11, 0x22)); + using MemoryStream ms = new(); + + bn.WriteTo(ms); + + var result = ms.ToArray(); + Assert.HasCount(3, result); + Assert.AreEqual(2, result[0]); // length prefix + CollectionAssert.AreEqual(Bytes(0x11, 0x22), result[1..]); + } + + // --------------------------------------------------------------------- + // Supporting Test Doubles + // --------------------------------------------------------------------- + + private sealed class UnreadableStream : Stream + { + public override bool CanRead => false; + public override bool CanSeek => false; + public override bool CanWrite => false; + public override long Length => throw new NotSupportedException(); + public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } + public override void Flush() => throw new NotSupportedException(); + public override int Read(byte[] buffer, int offset, int count) => throw new IOException("Unreadable"); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary/CollectionExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary/CollectionExtensionsTests.cs new file mode 100644 index 0000000..cf42ae6 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary/CollectionExtensionsTests.cs @@ -0,0 +1,262 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary +{ + [TestClass] + public sealed class CollectionExtensionsTests + { + // ------------------------------------------------------------- + // Shuffle + // ------------------------------------------------------------- + + [TestMethod] + public void Shuffle_ShouldRearrangeItems_WhenListHasMultipleElements() + { + // GIVEN + var input = new[] { 1, 2, 3, 4, 5 }; + var original = input.ToArray(); + + // WHEN + input.Shuffle(); + + // THEN + Assert.HasCount(original.Length, input, "Shuffle must not remove items."); + Assert.IsTrue(input.All(original.Contains), "Shuffle must retain all original items."); + } + + [TestMethod] + public void Shuffle_ShouldNotChangeSingleElementList() + { + // GIVEN + var input = new[] { 42 }; + + // WHEN + input.Shuffle(); + + // THEN + Assert.AreEqual(42, input[0]); + } + + [TestMethod] + public void Shuffle_ShouldNotThrow_WhenEmpty() + { + // GIVEN + var input = Array.Empty(); + + // WHEN + input.Shuffle(); + + // THEN + Assert.IsEmpty(input); + } + + // ------------------------------------------------------------- + // Convert (IReadOnlyList) + // ------------------------------------------------------------- + + [TestMethod] + public void Convert_List_ShouldTransformElements() + { + // GIVEN + IReadOnlyList input = new ReadOnlyCollection(new[] { 1, 2, 3 }); + + // WHEN + var result = input.Convert(x => x * 10); + + // THEN + Assert.HasCount(3, result); + Assert.AreEqual(10, result[0]); + Assert.AreEqual(20, result[1]); + Assert.AreEqual(30, result[2]); + } + + [TestMethod] + public void Convert_List_ShouldThrow_WhenArrayIsNull() + { + // GIVEN + IReadOnlyList? input = null; + + // WHEN + THEN + Assert.ThrowsExactly( + () => input.Convert(x => x * 10) + ); + } + + [TestMethod] + public void Convert_List_ShouldThrow_WhenConverterIsNull() + { + // GIVEN + IReadOnlyList input = Array.Empty(); + + // WHEN + THEN + Assert.ThrowsExactly( + () => input.Convert(null) + ); + } + + // ------------------------------------------------------------- + // Convert (IReadOnlyCollection) + // ------------------------------------------------------------- + + [TestMethod] + public void Convert_Collection_ShouldPreserveCount() + { + // GIVEN + IReadOnlyCollection input = new[] { "A", "BB", "CCC" }; + + // WHEN + var result = input.Convert(str => str.Length); + + // THEN + Assert.HasCount(3, result); + } + + [TestMethod] + public void Convert_Collection_ShouldThrow_WhenCollectionIsNull() + { + // GIVEN + IReadOnlyCollection input = null; + + // WHEN + THEN + Assert.ThrowsExactly( + () => input.Convert(x => x * 10) + ); + } + + [TestMethod] + public void Convert_Collection_ShouldThrow_WhenConverterIsNull() + { + // GIVEN + IReadOnlyCollection input = new[] { 1, 2 }; + + // WHEN + THEN + Assert.ThrowsExactly( + () => input.Convert(null) + ); + } + + // ------------------------------------------------------------- + // ListEquals + // ------------------------------------------------------------- + + [TestMethod] + public void ListEquals_ShouldReturnTrue_WhenSequencesMatchExactly() + { + // GIVEN + var a = new[] { 1, 2, 3 }; + var b = new[] { 1, 2, 3 }; + + // WHEN + var equal = a.ListEquals(b); + + // THEN + Assert.IsTrue(equal); + } + + [TestMethod] + public void ListEquals_ShouldReturnFalse_WhenLengthDiffers() + { + // GIVEN + var a = new[] { 1, 2 }; + var b = new[] { 1, 2, 3 }; + + // WHEN + var equal = a.ListEquals(b); + + // THEN + Assert.IsFalse(equal); + } + + [TestMethod] + public void ListEquals_ShouldReturnFalse_WhenElementDiffers() + { + // GIVEN + var a = new[] { 1, 2, 3 }; + var b = new[] { 1, 9, 3 }; + + // WHEN + var equal = a.ListEquals(b); + + // THEN + Assert.IsFalse(equal); + } + + [TestMethod] + public void ListEquals_ShouldReturnFalse_WhenSecondIsNull() + { + // GIVEN + var a = new[] { "X" }; + + // WHEN + var equal = a.ListEquals(null); + + // THEN + Assert.IsFalse(equal); + } + + // ------------------------------------------------------------- + // HasSameItems + // ------------------------------------------------------------- + + [TestMethod] + public void HasSameItems_ShouldReturnTrue_WhenSameElementsUnordered() + { + // GIVEN + var a = new[] { 3, 1, 2 }; + var b = new[] { 2, 3, 1 }; + + // WHEN + var equal = a.HasSameItems(b); + + // THEN + Assert.IsTrue(equal); + } + + [TestMethod] + public void HasSameItems_ShouldReturnFalse_WhenDifferentItemsPresent() + { + // GIVEN + var a = new[] { 1, 2, 3 }; + var b = new[] { 1, 2, 4 }; + + // WHEN + var equal = a.HasSameItems(b); + + // THEN + Assert.IsFalse(equal); + } + + // ------------------------------------------------------------- + // GetArrayHashCode + // ------------------------------------------------------------- + + [TestMethod] + public void GetArrayHashCode_ShouldReturnZero_WhenNull() + { + // WHEN + var hash = CollectionExtensions.GetArrayHashCode(null); + + // THEN + Assert.AreEqual(0, hash); + } + + [TestMethod] + public void GetArrayHashCode_ShouldMatchRegardlessOfOrder() + { + // GIVEN + var a = new[] { 10, 20, 30 }; + var b = new[] { 30, 10, 20 }; + + // WHEN + var hashA = a.GetArrayHashCode(); + var hashB = b.GetArrayHashCode(); + + // THEN + Assert.AreEqual(hashA, hashB, "XOR hash should not depend on order."); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary/IndependentTaskSchedulerTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary/IndependentTaskSchedulerTests.cs new file mode 100644 index 0000000..3b80ba8 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary/IndependentTaskSchedulerTests.cs @@ -0,0 +1,146 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Threading; +using System.Threading.Tasks; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary +{ + [TestClass] + public sealed class IndependentTaskSchedulerTests + { + [TestMethod] + public async Task Task_ShouldExecute_WhenQueued() + { + // GIVEN + using var scheduler = new IndependentTaskScheduler(maximumConcurrencyLevel: 1); + var completion = new TaskCompletionSource(); + + // WHEN + var t = new Task(_ => completion.SetResult(true), null); + t.Start(scheduler); + + // THEN + Assert.IsTrue(await completion.Task); + } + + [TestMethod] + public void MaximumConcurrencyLevel_ShouldMatchRequested() + { + // GIVEN + using var scheduler = new IndependentTaskScheduler(3); + + // WHEN + var level = scheduler.MaximumConcurrencyLevel; + + // THEN + Assert.AreEqual(3, level); + } + + [TestMethod] + public async Task Tasks_ShouldRunInParallel_WhenConcurrencyGreaterThanOne() + { + // GIVEN + using var scheduler = new IndependentTaskScheduler(maximumConcurrencyLevel: 2); + var parallelStarted = new TaskCompletionSource(); + var runningCount = 0; + + Task Body() => + Task.Run(() => + { + if (Interlocked.Increment(ref runningCount) == 2) + { + parallelStarted.SetResult(true); + } + Thread.Sleep(40); + }); + + // WHEN + _ = Task.Factory.StartNew(() => Body(), CancellationToken.None, TaskCreationOptions.None, scheduler).Unwrap(); + _ = Task.Factory.StartNew(() => Body(), CancellationToken.None, TaskCreationOptions.None, scheduler).Unwrap(); + + // THEN + Assert.IsTrue(await parallelStarted.Task); + } + + [TestMethod] + public void LongRunningOption_ShouldExecuteOnDedicatedThread() + { + // GIVEN + using var scheduler = new IndependentTaskScheduler(1); + var factoryThreadId = Thread.CurrentThread.ManagedThreadId; + var schedulerThreadId = -1; + + // WHEN + var task = new Task( + _ => schedulerThreadId = Thread.CurrentThread.ManagedThreadId, + null, + TaskCreationOptions.LongRunning); + + task.Start(scheduler); + task.Wait(); + + // THEN + Assert.AreNotEqual(factoryThreadId, schedulerThreadId); + } + + [TestMethod] + public async Task InlineExecution_ShouldRun_WhenCalledInsideSchedulerThread() + { + // GIVEN + using var scheduler = new IndependentTaskScheduler(1); + + var executedInline = new TaskCompletionSource(); + + // WHEN + var driver = new Task(() => + { + // Attempt inline execution from scheduler thread + var child = new Task(() => executedInline.SetResult(true)); + // This will execute inline because we are already inside scheduler thread + child.RunSynchronously(TaskScheduler.Current); + }); + + // Run the driver task inside scheduler + driver.Start(scheduler); + await driver; + + // THEN + Assert.IsTrue(await executedInline.Task, "Task must execute inline in scheduler thread."); + } + + [TestMethod] + public void Dispose_ShouldPreventFutureExecution() + { + // GIVEN + var scheduler = new IndependentTaskScheduler(1); + scheduler.Dispose(); + var task = new Task(() => { }); + + // WHEN + var continuation = Task.Factory.StartNew( + () => task.Start(scheduler), + CancellationToken.None, + TaskCreationOptions.None, + TaskScheduler.Default); + + continuation.Wait(); + + // THEN + Assert.IsFalse(task.IsCompleted); + } + + [TestMethod] + [System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "MSTEST0032:Assertion condition is always true", Justification = "Double Dispose must not throw")] + public void Dispose_CanBeCalledMultipleTimes_Safely() + { + // GIVEN + var scheduler = new IndependentTaskScheduler(); + + // WHEN + scheduler.Dispose(); + scheduler.Dispose(); + + // THEN + Assert.IsTrue(true); // simply must not throw + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary/JsonExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary/JsonExtensionsTests.cs new file mode 100644 index 0000000..2233a0e --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary/JsonExtensionsTests.cs @@ -0,0 +1,292 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Linq; +using System.Text.Json; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary +{ + [TestClass] + public sealed class JsonExtensionsTests + { + private static JsonElement ToElement(string json) + { + using var doc = JsonDocument.Parse(json); + return doc.RootElement.Clone(); + } + + // ------------------------------ + // ARRAY READING (STRING) + // ------------------------------ + + [TestMethod] + public void GetArray_ShouldReturnStringArray_WhenArrayExists() + { + // GIVEN + var json = ToElement("""{ "values": ["a", "b", "c"] }"""); + + // WHEN + var result = json.ReadArray("values"); + + // THEN + Assert.HasCount(3, result); + Assert.AreEqual("a", result[0]); + Assert.AreEqual("b", result[1]); + Assert.AreEqual("c", result[2]); + } + + [TestMethod] + public void GetArray_ShouldReturnNull_WhenJsonContainsNull() + { + // GIVEN + var json = ToElement("""{ "values": null }"""); + + // WHEN + var result = json.ReadArray("values"); + + // THEN + Assert.IsNull(result); + } + + [TestMethod] + public void GetArray_ShouldThrow_WhenPropertyIsNotArrayOrNull() + { + // GIVEN + var json = ToElement("""{ "values": 123 }"""); + + // WHEN–THEN + Assert.ThrowsExactly(() => json.ReadArray("values")); + } + + // ------------------------------ + // ARRAY READING WITH MAPPING (string→int) + // ------------------------------ + + [TestMethod] + public void ReadArray_WithConverter_ShouldReturnMappedArray() + { + // GIVEN + var json = ToElement("""{ "values": ["1","2","3"] }"""); + + // WHEN + var result = json.ReadArray("values", int.Parse); + + // THEN + Assert.HasCount(3, result); + Assert.AreEqual(1, result[0]); + Assert.AreEqual(2, result[1]); + Assert.AreEqual(3, result[2]); + } + + [TestMethod] + public void ReadArray_WithConverter_ShouldThrow_WhenConverterThrows() + { + // GIVEN + var json = ToElement("""{ "values": ["bad"] }"""); + + // WHEN–THEN + Assert.ThrowsExactly(() => + json.ReadArray("values", s => int.Parse(s))); + } + + [TestMethod] + public void TryReadArray_WithConverter_ShouldReturnFalse_WhenPropertyMissing() + { + // GIVEN + var json = ToElement("""{ "other": [1,2] }"""); + + // WHEN + var result = json.TryReadArray("values", int.Parse, out var array); + + // THEN + Assert.IsFalse(result); + Assert.IsNull(array); + } + + [TestMethod] + public void TryReadArray_WithConverter_ShouldReturnTrue_WhenArrayExists() + { + // GIVEN + var json = ToElement("""{ "values": ["10","20"] }"""); + + // WHEN + var result = json.TryReadArray("values", int.Parse, out var array); + + // THEN + Assert.IsTrue(result); + Assert.HasCount(2, array); + Assert.AreEqual(10, array[0]); + Assert.AreEqual(20, array[1]); + } + + // ------------------------------ + // READ SET + // ------------------------------ + + [TestMethod] + public void ReadArrayAsSet_ShouldReturnHashSetOfUniqueValues() + { + // GIVEN + var json = ToElement("""{ "values": ["a","b","a"] }"""); + + // WHEN + var result = json.ReadArrayAsSet("values"); + + // THEN + Assert.HasCount(2, result); + Assert.Contains("a", result); + Assert.Contains("b", result); + } + + [TestMethod] + public void TryReadArrayAsSet_ShouldReturnFalse_WhenNoProperty() + { + // GIVEN + var json = ToElement("""{ "other": [] }"""); + + // WHEN + var result = json.TryReadArrayAsSet("values", out var set); + + // THEN + Assert.IsFalse(result); + Assert.IsNull(set); + } + + // ------------------------------ + // MAP READING + // ------------------------------ + + [TestMethod] + public void ReadArrayAsMap_ShouldReturnDictionary_WhenMappingReturnsPairs() + { + // GIVEN + var json = ToElement("""{ "values": [ { "k":"x","v":"1" }, { "k":"y","v":"2"} ] }"""); + + // WHEN + var result = json.ReadArrayAsMap("values", el => + { + var key = el.GetProperty("k").GetString(); +#pragma warning disable CS8604 // Possible null reference argument. + var val = int.Parse(el.GetProperty("v").GetString()); +#pragma warning restore CS8604 // Possible null reference argument. + return Tuple.Create(key, val); + }); + + // THEN + Assert.HasCount(2, result); + Assert.AreEqual(1, result["x"]); + Assert.AreEqual(2, result["y"]); + } + + [TestMethod] + public void TryReadArrayAsMap_ShouldReturnFalse_WhenPropertyMissing() + { + // GIVEN + var json = ToElement("""{ "other": [] }"""); + + // WHEN + var result = json.TryReadArrayAsMap("values", _ => null, out var map); + + // THEN + Assert.IsFalse(result); + Assert.IsNull(map); + } + + [TestMethod] + public void ReadArrayAsMap_ShouldIgnoreNullReturnedPairs() + { + var json = ToElement(""" + { "arr": [123, 456] } + """); + + var result = json.ReadArrayAsMap("arr", _ => null); + + Assert.IsEmpty(result); + } + + + // ------------------------------ + // GET PROPERTY VALUE + // ------------------------------ + + [TestMethod] + public void GetPropertyValue_String_ShouldReturnDefault_WhenMissing() + { + // GIVEN + var json = ToElement("""{ "name": "test" }"""); + + // WHEN + var value = json.GetPropertyValue("missing", "default"); + + // THEN + Assert.AreEqual("default", value); + } + + [TestMethod] + public void GetPropertyValue_Int_ShouldReturnStoredValue() + { + // GIVEN + var json = ToElement("""{ "value": 42 }"""); + + // WHEN + var value = json.GetPropertyValue("value", -1); + + // THEN + Assert.AreEqual(42, value); + } + + [TestMethod] + public void GetPropertyEnumValue_ShouldReturnEnum() + { + // GIVEN + var json = ToElement("""{ "mode": "Friday" }"""); + + // WHEN + var result = json.GetPropertyEnumValue("mode", DayOfWeek.Monday); + + // THEN + Assert.AreEqual(DayOfWeek.Friday, result); + } + + [TestMethod] + public void GetPropertyEnumValue_ShouldReturnDefault_WhenNotFound() + { + // GIVEN + var json = ToElement("""{ "val": 10 }"""); + + // WHEN + var result = json.GetPropertyEnumValue("missing", DayOfWeek.Sunday); + + // THEN + Assert.AreEqual(DayOfWeek.Sunday, result); + } + + // ------------------------------ + // WRITE ARRAY + // ------------------------------ + + [TestMethod] + public void WriteStringArray_ShouldSerializeStrings_AsJsonArray() + { + // GIVEN + var input = new[] { "x", "y", "z" }; + using var buffer = new System.IO.MemoryStream(); + using var writer = new Utf8JsonWriter(buffer); + + // WHEN + writer.WriteStartObject(); + writer.WriteStringArray("values", input); + writer.WriteEndObject(); + writer.Flush(); + + var json = JsonDocument.Parse(buffer.ToArray()).RootElement; + + // THEN + var arr = json.GetProperty("values").EnumerateArray().Select(x => x.GetString()).ToArray(); + + Assert.HasCount(3, arr); + Assert.AreEqual("x", arr[0]); + Assert.AreEqual("y", arr[1]); + Assert.AreEqual("z", arr[2]); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary/StringExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary/StringExtensionsTests.cs new file mode 100644 index 0000000..ea8898e --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary/StringExtensionsTests.cs @@ -0,0 +1,184 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary +{ + [TestClass] + public sealed class StringExtensionsTests + { + // ----------------------------- + // Split + // ----------------------------- + + [TestMethod] + public void Split_ShouldConvertItems_WhenParsingSucceeds() + { + // GIVEN + const string input = "1, 2, 3"; + + // WHEN + var result = input.Split(int.Parse, ','); + + // THEN + Assert.HasCount(3, result); + Assert.AreEqual(1, result[0]); + Assert.AreEqual(2, result[1]); + Assert.AreEqual(3, result[2]); + } + + [TestMethod] + public void Split_ShouldRemoveEmptyEntries_AndTrim() + { + // GIVEN + const string input = " 10 ; ; 20 ; 30 "; + + // WHEN + var result = input.Split(int.Parse, ';'); + + // THEN + Assert.HasCount(3, result); + Assert.AreEqual(10, result[0]); + Assert.AreEqual(20, result[1]); + Assert.AreEqual(30, result[2]); + } + + [TestMethod] + public void Split_ShouldThrow_WhenParserThrows() + { + // GIVEN + const string input = "10, BAD"; + + // WHEN–THEN + Assert.ThrowsExactly(() => _ = input.Split(int.Parse, ',')); + } + + [TestMethod] + public void Split_ShouldThrow_WhenStringIsNull() + { + // GIVEN + const string? input = null; + + // WHEN–THEN + Assert.ThrowsExactly(() => + _ = input.Split(int.Parse, ',')); + } + + // ----------------------------- + // Join + // ----------------------------- + + [TestMethod] + public void Join_ShouldReturnCommaSeparatedValues() + { + // GIVEN + var input = new[] { 1, 2, 3 }; + + // WHEN + var result = input.Join(','); + + // THEN + Assert.AreEqual("1, 2, 3", result); + } + + [TestMethod] + public void Join_ShouldReturnNull_WhenCollectionEmpty() + { + // GIVEN + var input = Array.Empty(); + + // WHEN + var result = input.Join(','); + + // THEN + Assert.IsNull(result); + } + + [TestMethod] + public void Join_ShouldThrow_WhenValuesIsNull() + { + // GIVEN + int[]? input = null; + + // WHEN–THEN + Assert.ThrowsExactly(() => input.Join(',')); + } + + // ----------------------------- + // ParseColonHexString + // ----------------------------- + + [TestMethod] + public void ParseColonHexString_ShouldReturnBytes_WhenValidHex() + { + // GIVEN + const string input = "0A:FF:01"; + + // WHEN + var result = input.ParseColonHexString(); + + // THEN + Assert.HasCount(3, result); + Assert.AreEqual(0x0A, result[0]); + Assert.AreEqual(0xFF, result[1]); + Assert.AreEqual(0x01, result[2]); + } + + [TestMethod] + public void ParseColonHexString_ShouldThrow_WhenInvalidHex() + { + // GIVEN + const string input = "GG:12"; + + // WHEN–THEN + Assert.ThrowsExactly(() => + _ = input.ParseColonHexString()); + } + + [TestMethod] + public void ParseColonHexString_ShouldThrow_WhenValueNotHex() + { + // GIVEN + const string input = "1K"; + + // WHEN–THEN + Assert.ThrowsExactly(() => + _ = input.ParseColonHexString()); + } + + [TestMethod] + public void ParseColonHexString_ShouldThrow_WhenInputContainsEmptySegments() + { + // GIVEN + const string input = "FF::AA"; + + // WHEN–THEN + Assert.ThrowsExactly(() => + _ = input.ParseColonHexString()); + } + + [TestMethod] + public void ParseColonHexString_ShouldThrow_WhenValueIsNull() + { + // GIVEN + const string? input = null; + + // WHEN–THEN + Assert.ThrowsExactly(() => + _ = input.ParseColonHexString()); + } + + [TestMethod] + public void ParseColonHexString_ShouldSupportSingleSegment() + { + // GIVEN + const string input = "FE"; + + // WHEN + var result = input.ParseColonHexString(); + + // THEN + Assert.HasCount(1, result); + Assert.AreEqual(0xFE, result[0]); + } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary/TaskExtensionsTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary/TaskExtensionsTests.cs new file mode 100644 index 0000000..f2175e0 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary/TaskExtensionsTests.cs @@ -0,0 +1,210 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary +{ + [TestClass] + public sealed class TaskExtensionsTests + { + // Helper allowing deterministic near-timeout simulation + private static Task NeverCompletes(CancellationToken _) => + new TaskCompletionSource().Task; + + // --------------------------------------------- + // TimeoutAsync (non-returning) + // --------------------------------------------- + + [TestMethod] + public async Task TimeoutAsync_ShouldComplete_WhenTaskFinishesBeforeTimeout() + { + // GIVEN + Func func = _ => Task.Delay(50, TestContext.CancellationToken); + + // WHEN-THEN + await TaskExtensions.TimeoutAsync(func, timeout: 500, TestContext.CancellationToken); + } + + [TestMethod] + public async Task TimeoutAsync_ShouldThrowTimeoutException_WhenOperationExceedsTimeout() + { + // GIVEN + Func func = NeverCompletes; + + // WHEN-THEN + await Assert.ThrowsExactlyAsync(() => + TaskExtensions.TimeoutAsync(func, timeout: 50, TestContext.CancellationToken)); + } + + [TestMethod] + public async Task TimeoutAsync_ShouldThrowOriginalException_WhenTaskFails() + { + // GIVEN + Func func = _ => throw new InvalidOperationException("boom"); + + // WHEN-THEN + await Assert.ThrowsExactlyAsync(() => + TaskExtensions.TimeoutAsync(func, timeout: 500, TestContext.CancellationToken)); + } + + [TestMethod] + public async Task TimeoutAsync_ShouldThrowOperationCanceled_WhenRootTokenCancelled() + { + // GIVEN + using var cts = new CancellationTokenSource(); + Func func = NeverCompletes; + + // WHEN + await cts.CancelAsync(); + + // THEN + await Assert.ThrowsExactlyAsync(() => + TaskExtensions.TimeoutAsync(func, timeout: 200, cancellationToken: cts.Token)); + } + + // --------------------------------------------- + // TimeoutAsync (generic) + // --------------------------------------------- + + [TestMethod] + public async Task TimeoutAsync_Generic_ShouldReturnValue_WhenCompletedWithinTimeout() + { + // GIVEN + Func> func = _ => Task.FromResult(42); + + // WHEN + var result = await TaskExtensions.TimeoutAsync(func, timeout: 300, TestContext.CancellationToken); + + // THEN + Assert.AreEqual(42, result); + } + + [TestMethod] + public async Task TimeoutAsync_Generic_ShouldThrowTimeoutException_WhenTaskRunsTooLong() + { + // GIVEN + Func> func = async _ => + { + await Task.Delay(2000, TestContext.CancellationToken); + return 5; + }; + + // WHEN-THEN + await Assert.ThrowsExactlyAsync(() => + TaskExtensions.TimeoutAsync(func, timeout: 50, TestContext.CancellationToken)); + } + + [TestMethod] + public async Task TimeoutAsync_Generic_ShouldPropagateSourceException() + { + // GIVEN + Func> func = + _ => throw new FormatException("fail"); + + // WHEN-THEN + await Assert.ThrowsExactlyAsync(() => + TaskExtensions.TimeoutAsync(func, timeout: 500, TestContext.CancellationToken)); + } + + // --------------------------------------------- + // Sync() Task + // --------------------------------------------- + + [TestMethod] + public void Sync_ShouldBlockUntilCompleted() + { + // GIVEN + var task = Task.Delay(50, TestContext.CancellationToken); + + // WHEN-THEN + task.Sync(); + } + + [TestMethod] + public void Sync_ShouldRethrowOriginalException() + { + // GIVEN + var task = Task.FromException(new InvalidOperationException("bad")); + + // WHEN-THEN + Assert.ThrowsExactly(() => task.Sync()); + } + + [TestMethod] + public void Sync_ShouldThrowNullReference_WhenTaskIsNull() + { + // GIVEN + Task? task = null; + + // WHEN-THEN + Assert.ThrowsExactly(() => task!.Sync()); + } + + // --------------------------------------------- + // Sync() Task + // --------------------------------------------- + + [TestMethod] + public void Sync_Generic_ShouldReturnValue() + { + // GIVEN + var task = Task.FromResult(123); + + // WHEN + var result = task.Sync(); + + // THEN + Assert.AreEqual(123, result); + } + + [TestMethod] + public void Sync_Generic_ShouldSurfaceException() + { + // GIVEN + var task = Task.FromException(new FormatException()); + + // WHEN-THEN + Assert.ThrowsExactly(() => task.Sync()); + } + + [TestMethod] + public void Sync_Generic_ShouldThrowOnNullTask() + { + // GIVEN + Task? task = null; + + // WHEN-THEN + Assert.ThrowsExactly(() => task!.Sync()); + } + + // --------------------------------------------- + // Sync() ValueTask / ValueTask + // --------------------------------------------- + + [TestMethod] + public void Sync_ValueTask_ShouldBlockUntilCompletion() + { + // GIVEN + var vt = new ValueTask(Task.Delay(50, TestContext.CancellationToken)); + + // WHEN-THEN + vt.Sync(); + } + + [TestMethod] + public void Sync_ValueTask_Generic_ShouldReturnValue() + { + // GIVEN + var vt = new ValueTask(987); + + // WHEN + var result = vt.Sync(); + + // THEN + Assert.AreEqual(987, result); + } + + public TestContext TestContext { get; set; } + } +} diff --git a/TechnitiumLibrary.Tests/TechnitiumLibrary/TaskPoolTests.cs b/TechnitiumLibrary.Tests/TechnitiumLibrary/TaskPoolTests.cs new file mode 100644 index 0000000..eef8cd9 --- /dev/null +++ b/TechnitiumLibrary.Tests/TechnitiumLibrary/TaskPoolTests.cs @@ -0,0 +1,153 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Threading.Tasks; + +namespace TechnitiumLibrary.Tests.TechnitiumLibrary +{ + [TestClass] + public sealed class TaskPoolTests + { + [TestMethod] + public async Task TryQueueTask_ShouldExecuteQueuedTask() + { + // GIVEN + var pool = new TaskPool(queueSize: 10, maximumConcurrencyLevel: 2); + var completer = new TaskCompletionSource(); + + // WHEN + var queued = pool.TryQueueTask(_ => + { + completer.SetResult(true); + return Task.CompletedTask; + }); + + // THEN + Assert.IsTrue(queued, "Task should be accepted into queue."); + Assert.IsTrue(await completer.Task, "Task must execute."); + } + + [TestMethod] + public async Task ShouldProcessMultipleTasksConcurrently_WhenAllowed() + { + // GIVEN + var parallelism = Environment.ProcessorCount; + var pool = new TaskPool(queueSize: 64, maximumConcurrencyLevel: parallelism); + + var counter = 0; + var completion = new TaskCompletionSource(); + var lockObj = new object(); + + int total = parallelism; + + // WHEN + for (int i = 0; i < total; i++) + { + pool.TryQueueTask(_ => + { + lock (lockObj) + counter++; + + if (counter == total) + completion.SetResult(true); + + return Task.CompletedTask; + }); + } + + // THEN + Assert.IsTrue(await completion.Task, "All tasks must execute."); + Assert.AreEqual(total, counter, "All queued tasks must run."); + } + + [TestMethod] + public async Task TasksShouldStopAfterDispose() + { + // GIVEN + var pool = new TaskPool(queueSize: 10, maximumConcurrencyLevel: 1); + + var executedBeforeDispose = new TaskCompletionSource(); + var wasExecutedAfterDispose = false; + + pool.TryQueueTask(_ => + { + executedBeforeDispose.SetResult(true); + return Task.CompletedTask; + }); + + await executedBeforeDispose.Task; + + // WHEN + pool.Dispose(); + var acceptedPostDispose = pool.TryQueueTask(_ => + { + wasExecutedAfterDispose = true; + return Task.CompletedTask; + }); + + // THEN + Assert.IsFalse(acceptedPostDispose, "After disposal, queue must reject writes."); + Assert.IsFalse(wasExecutedAfterDispose, "Tasks queued after Dispose must not run."); + } + + [TestMethod] + public void Ctor_ShouldUseDefaultConcurrency_WhenValueIsLessThanOne() + { + // GIVEN + WHEN + var pool = new TaskPool(queueSize: 10, maximumConcurrencyLevel: -1); + + // THEN + Assert.IsGreaterThanOrEqualTo(1, +pool.MaximumConcurrencyLevel, "Concurrency must fallback to processor count."); + } + + [TestMethod] + public void TryQueueTask_ShouldThrow_WhenTaskIsNull() + { + // GIVEN + var pool = new TaskPool(); + + // WHEN + THEN + Assert.ThrowsExactly(() => pool.TryQueueTask(null)); + } + + [TestMethod] + public async Task TaskShouldReceiveStateObject() + { + // GIVEN + var pool = new TaskPool(); + var completion = new TaskCompletionSource(); + + var expectedState = "STATE"; + var capturedState = default(string); + + // WHEN + pool.TryQueueTask(obj => + { + capturedState = obj as string; + completion.SetResult(true); + return Task.CompletedTask; + }, expectedState); + + await completion.Task; + + // THEN + Assert.AreEqual(expectedState, capturedState, "State parameter must propagate through execution."); + } + + [TestMethod] + [System.Diagnostics.CodeAnalysis.SuppressMessage("Design", "MSTEST0032:Assertion condition is always true", Justification = "Multiple Dispose must not throw")] + public void DisposeMustBeIdempotent() + { + // GIVEN + var pool = new TaskPool(); + + // WHEN + pool.Dispose(); + pool.Dispose(); + pool.Dispose(); + + // THEN + Assert.IsTrue(true, "Dispose must not throw."); + } + } +} diff --git a/TechnitiumLibrary.sln b/TechnitiumLibrary.sln index 9cbcda3..9c24271 100644 --- a/TechnitiumLibrary.sln +++ b/TechnitiumLibrary.sln @@ -25,6 +25,13 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TechnitiumLibrary", "Techni EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TechnitiumLibrary.Security.OTP", "TechnitiumLibrary.Security.OTP\TechnitiumLibrary.Security.OTP.csproj", "{72AF4EB6-EB81-4655-9998-8BF24B304614}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TechnitiumLibrary.Tests", "TechnitiumLibrary.Tests\TechnitiumLibrary.Tests.csproj", "{FD16DA2F-1446-45BA-929A-89D4E233F4C2}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{02EA681E-C7D8-13C7-8484-4AC65E1B71E8}" + ProjectSection(SolutionItems) = preProject + .editorconfig = .editorconfig + EndProjectSection +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -75,6 +82,10 @@ Global {72AF4EB6-EB81-4655-9998-8BF24B304614}.Debug|Any CPU.Build.0 = Debug|Any CPU {72AF4EB6-EB81-4655-9998-8BF24B304614}.Release|Any CPU.ActiveCfg = Release|Any CPU {72AF4EB6-EB81-4655-9998-8BF24B304614}.Release|Any CPU.Build.0 = Release|Any CPU + {FD16DA2F-1446-45BA-929A-89D4E233F4C2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {FD16DA2F-1446-45BA-929A-89D4E233F4C2}.Debug|Any CPU.Build.0 = Debug|Any CPU + {FD16DA2F-1446-45BA-929A-89D4E233F4C2}.Release|Any CPU.ActiveCfg = Release|Any CPU + {FD16DA2F-1446-45BA-929A-89D4E233F4C2}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/TechnitiumLibrary/Base32.cs b/TechnitiumLibrary/Base32.cs index e5d6b8b..13fc225 100644 --- a/TechnitiumLibrary/Base32.cs +++ b/TechnitiumLibrary/Base32.cs @@ -154,6 +154,19 @@ private static string Encode(Span data, char[] map, bool skipPadding) private static byte[] Decode(string data, int[] rmap) { + if (data is null) + { + throw new ArgumentNullException(nameof(data)); + } + if (data == string.Empty) + { + return Array.Empty(); + } + if (data.Contains(' ')) + { + throw new ArgumentException("Base32 string cannot contain spaces.", nameof(data)); + } + byte[] buffer; int paddingCount = 0; @@ -317,4 +330,4 @@ public static byte[] FromBase32HexString(string data) #endregion } -} +} \ No newline at end of file diff --git a/TechnitiumLibrary/BinaryNumber.cs b/TechnitiumLibrary/BinaryNumber.cs index 747a20b..833c0ec 100644 --- a/TechnitiumLibrary/BinaryNumber.cs +++ b/TechnitiumLibrary/BinaryNumber.cs @@ -44,7 +44,13 @@ public BinaryNumber(Stream s) public BinaryNumber(BinaryReader bR) { - _value = bR.ReadBytes(bR.Read7BitEncodedInt()); + var length = bR.Read7BitEncodedInt(); + var data = bR.ReadBytes(length); + + if (data.Length != length) + throw new EndOfStreamException("Not enough bytes in stream to build BinaryNumber."); + + _value = data; } #endregion diff --git a/TechnitiumLibrary/CollectionExtensions.cs b/TechnitiumLibrary/CollectionExtensions.cs index 91a04cd..6def981 100644 --- a/TechnitiumLibrary/CollectionExtensions.cs +++ b/TechnitiumLibrary/CollectionExtensions.cs @@ -40,8 +40,13 @@ public static void Shuffle(this IList array) public static IReadOnlyList Convert(this IReadOnlyList array, Func convert) { - T2[] newArray = new T2[array.Count]; + if (array is null) + throw new ArgumentNullException(nameof(array)); + + if (convert is null) + throw new ArgumentNullException(nameof(convert)); + T2[] newArray = new T2[array.Count]; for (int i = 0; i < array.Count; i++) newArray[i] = convert(array[i]); @@ -50,9 +55,14 @@ public static IReadOnlyList Convert(this IReadOnlyList array, Fu public static IReadOnlyCollection Convert(this IReadOnlyCollection collection, Func convert) { + if (collection is null) + throw new ArgumentNullException(nameof(collection)); + + if (convert is null) + throw new ArgumentNullException(nameof(convert)); + T2[] newArray = new T2[collection.Count]; int i = 0; - foreach (T1 item in collection) newArray[i++] = convert(item); diff --git a/TechnitiumLibrary/TaskPool.cs b/TechnitiumLibrary/TaskPool.cs index bf1b479..c2fd8ae 100644 --- a/TechnitiumLibrary/TaskPool.cs +++ b/TechnitiumLibrary/TaskPool.cs @@ -102,6 +102,10 @@ public bool TryQueueTask(Func task) public bool TryQueueTask(Func task, object state) { + if (task is null) + { + throw new ArgumentNullException(nameof(task)); + } return _channelWriter.TryWrite((task, state)); }