diff --git a/src/CommonLib/Helpers.cs b/src/CommonLib/Helpers.cs index 6b89ae9c9..b4d886967 100644 --- a/src/CommonLib/Helpers.cs +++ b/src/CommonLib/Helpers.cs @@ -12,12 +12,14 @@ using SharpHoundCommonLib.Processors; using Microsoft.Win32; using System.Threading.Tasks; +using System.Threading; namespace SharpHoundCommonLib { public static class Helpers { private static readonly HashSet Groups = new() { "268435456", "268435457", "536870912", "536870913" }; private static readonly HashSet Computers = new() { "805306369" }; private static readonly HashSet Users = new() { "805306368", "805306370" }; + private static readonly double MaxTimeSpanTicks = (double)TimeSpan.MaxValue.Ticks - 1_000; private static readonly Regex DCReplaceRegex = new("DC=", RegexOptions.IgnoreCase | RegexOptions.Compiled); private static readonly Regex SPNRegex = new(@".*\/.*", RegexOptions.Compiled); @@ -318,15 +320,28 @@ public static string DumpDirectoryObject(this IDirectoryObject directoryObject) return builder.ToString(); } + public static TimeSpan BackoffWithDecorrelatedJitter(int attempt, TimeSpan baseDelay, TimeSpan maxDelay) { + // Decorrelated Jitter Backoff - see https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ + var temp = Math.Min(maxDelay.Ticks, baseDelay.Ticks * (attempt * attempt)); + temp = temp / 2 + RandomUtils.Between(0, temp / 2); + var ticksToDelay = Math.Min(maxDelay.Ticks, RandomUtils.Between(baseDelay.Ticks, temp * 3)); + + // This ensures that a TimeSpan can be created with the ticks amount as TimeSpan uses a long. + return double.IsInfinity(ticksToDelay) ? TimeSpan.FromTicks((long)MaxTimeSpanTicks) : + TimeSpan.FromTicks((long)Math.Min(MaxTimeSpanTicks, ticksToDelay)); + } + /// /// Attempt an action a number of times, quietly eating a specific exception until the last attempt if it throws. /// /// /// /// - public static async Task RetryOnException(Func action, int retryCount, ILogger logger = null) where T : Exception { + public static async Task RetryOnException(Func action, int retryCount, TimeSpan? baseDelay = null, TimeSpan? maxDelay = null, ILogger logger = null) where T : Exception { int attempt = 0; bool success = false; + baseDelay ??= TimeSpan.FromSeconds(1); + maxDelay ??= TimeSpan.FromSeconds(30); do { try { await action(); @@ -337,9 +352,33 @@ public static async Task RetryOnException(Func action, int retryCount, logger?.LogDebug(e, "Exception caught, retrying attempt {Attempt}", attempt); if (attempt >= retryCount) throw; + + var delay = BackoffWithDecorrelatedJitter(attempt, baseDelay.Value, maxDelay.Value); + await Task.Delay(delay); } } while (!success && attempt < retryCount); } + + public static async Task RetryOnException(Func action, int retryCount, TimeSpan? baseDelay = null, TimeSpan? maxDelay = null, ILogger logger = null) where T : Exception { + int attempt = 0; + baseDelay ??= TimeSpan.FromSeconds(1); + maxDelay ??= TimeSpan.FromSeconds(30); + do { + try { + return action(); + } + catch (T e) { + attempt++; + logger?.LogDebug(e, "Exception caught, retrying attempt {Attempt}", attempt); + if (attempt >= retryCount) + throw; + var delay = BackoffWithDecorrelatedJitter(attempt, baseDelay.Value, maxDelay.Value); + await Task.Delay(delay); + } + } while (attempt < retryCount); + + throw new InvalidOperationException($"You really shouldn't be here, {nameof(RetryOnException)} isn't working as intended."); + } } public class ParsedGPLink { diff --git a/src/CommonLib/LdapConnectionPool.cs b/src/CommonLib/LdapConnectionPool.cs index bc0aede32..04e99be5f 100644 --- a/src/CommonLib/LdapConnectionPool.cs +++ b/src/CommonLib/LdapConnectionPool.cs @@ -37,9 +37,6 @@ internal class LdapConnectionPool : IDisposable { private const int MaxRetries = 3; private static readonly ConcurrentDictionary DCInfoCache = new(); - // Tracks domains we know we've determined we shouldn't try to connect to - private static readonly ConcurrentHashSet _excludedDomains = new(); - public LdapConnectionPool(string identifier, string poolIdentifier, LdapConfig config, IPortScanner scanner = null, NativeMethods nativeMethods = null, ILogger log = null) { _connections = new ConcurrentBag(); @@ -693,7 +690,7 @@ private bool CallDsGetDcName(string domainName, out NetAPIStructs.DomainControll public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetConnectionAsync() { - if (_excludedDomains.Contains(_identifier)) { + if (LdapUtils.IsExcludedDomain(_identifier)) { return (false, null, $"Identifier {_identifier} excluded for connection attempt"); } @@ -727,7 +724,7 @@ private bool CallDsGetDcName(string domainName, out NetAPIStructs.DomainControll public async Task<(bool Success, LdapConnectionWrapper ConnectionWrapper, string Message)> GetGlobalCatalogConnectionAsync() { - if (_excludedDomains.Contains(_identifier)) { + if (LdapUtils.IsExcludedDomain(_identifier)) { return (false, null, $"Identifier {_identifier} excluded for connection attempt"); } @@ -813,7 +810,7 @@ await CreateLdapConnection(tempDomainName, globalCatalog) is (true, var connecti _log.LogDebug( "Could not get domain object from GetDomain, unable to create ldap connection for domain {Domain}", _identifier); - _excludedDomains.Add(_identifier); + LdapUtils.AddExcludedDomain(_identifier); return (false, null, "Unable to get domain object for further strategies"); } @@ -852,7 +849,7 @@ await CreateLdapConnection(tempDomainName, globalCatalog) is (true, var connecti catch (Exception e) { _log.LogInformation(e, "We will not be able to connect to domain {Domain} by any strategy, leaving it.", _identifier); - _excludedDomains.Add(_identifier); + LdapUtils.AddExcludedDomain(_identifier); } return (false, null, "All attempted connections failed"); diff --git a/src/CommonLib/LdapUtils.cs b/src/CommonLib/LdapUtils.cs index 14612da12..a13938e92 100644 --- a/src/CommonLib/LdapUtils.cs +++ b/src/CommonLib/LdapUtils.cs @@ -30,6 +30,9 @@ public class LdapUtils : ILdapUtils { private static ConcurrentDictionary _domainCache = new(); private static ConcurrentHashSet _domainControllers = new(StringComparer.OrdinalIgnoreCase); private static ConcurrentHashSet _unresolvablePrincipals = new(StringComparer.OrdinalIgnoreCase); + + // Tracks Domains we know we've determined we shouldn't try to connect to + private static ConcurrentHashSet _excludedDomains = new(StringComparer.OrdinalIgnoreCase); private static readonly ConcurrentDictionary DomainToForestCache = new(StringComparer.OrdinalIgnoreCase); @@ -50,7 +53,7 @@ private readonly ConcurrentDictionary private readonly ILogger _log; private readonly IPortScanner _portScanner; private readonly NativeMethods _nativeMethods; - private readonly string _nullCacheKey = Guid.NewGuid().ToString(); + private static readonly string _nullCacheKey = Guid.NewGuid().ToString(); private static readonly Regex SIDRegex = new(@"^(S-\d+-\d+-\d+-\d+-\d+-\d+)(-\d+)?$"); private readonly string[] _translateNames = { "Administrator", "admin" }; @@ -506,12 +509,14 @@ public bool GetDomain(string domainName, out Domain domain) { : new DirectoryContext(DirectoryContextType.Domain); // Blocking External Call - domain = Domain.GetDomain(context); + domain = Helpers.RetryOnException(() => Domain.GetDomain(context), 2).GetAwaiter().GetResult(); if (domain == null) return false; _domainCache.TryAdd(cacheKey, domain); return true; } catch (Exception e) { + // The Static GetDomain Function ran into an issue requiring to exclude a domain as it would continuously + // try to connect to a domain that it could not connect to. This method may also need the same logic. _log.LogDebug(e, "GetDomain call failed for domain name {Name}", domainName); domain = null; return false; @@ -519,7 +524,13 @@ public bool GetDomain(string domainName, out Domain domain) { } public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domain domain) { + var cacheKey = domainName ?? _nullCacheKey; if (_domainCache.TryGetValue(domainName, out domain)) return true; + if (IsExcludedDomain(domainName)) { + Logging.Logger.LogDebug("Domain: {DomainName} has been excluded for collection. Skipping", domainName); + domain = null; + return false; + } try { DirectoryContext context; @@ -535,14 +546,17 @@ public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domai : new DirectoryContext(DirectoryContextType.Domain); // Blocking External Call - domain = Domain.GetDomain(context); + domain = Helpers.RetryOnException(() => Domain.GetDomain(context), 2).GetAwaiter().GetResult(); if (domain == null) return false; - _domainCache.TryAdd(domainName, domain); + _domainCache.TryAdd(cacheKey, domain); return true; } catch (Exception e) { - Logging.Logger.LogDebug("Static GetDomain call failed for domain {DomainName}: {Error}", domainName, + Logging.Logger.LogDebug("Static GetDomain call failed, adding to exclusion, for domain {DomainName}: {Error}", domainName, e.Message); + // If a domain cannot be contacted, this will exclude the domain so that it does not continuously try to connect, and + // cause more timeouts. + AddExcludedDomain(cacheKey); domain = null; return false; } @@ -565,11 +579,13 @@ public bool GetDomain(out Domain domain) { : new DirectoryContext(DirectoryContextType.Domain); // Blocking External Call - domain = Domain.GetDomain(context); + domain = Helpers.RetryOnException(() => Domain.GetDomain(context), 2).GetAwaiter().GetResult(); _domainCache.TryAdd(_nullCacheKey, domain); return true; } catch (Exception e) { + // The Static GetDomain Function ran into an issue requiring to exclude a domain as it would continuously + // try to connect to a domain that it could not connect to. This method may also need the same logic. _log.LogDebug(e, "GetDomain call failed for blank domain"); domain = null; return false; @@ -1129,6 +1145,7 @@ public void ResetUtils() { _domainControllers = new ConcurrentHashSet(StringComparer.OrdinalIgnoreCase); _connectionPool?.Dispose(); _connectionPool = new ConnectionPoolManager(_ldapConfig, scanner: _portScanner); + _excludedDomains = new ConcurrentHashSet(StringComparer.OrdinalIgnoreCase); } private IDirectoryObject CreateDirectoryEntry(string path) { @@ -1143,6 +1160,9 @@ public void Dispose() { _connectionPool?.Dispose(); } + public static bool IsExcludedDomain(string domain) => _excludedDomains.Contains(domain); + public static void AddExcludedDomain(string domain) => _excludedDomains.Add(domain); + internal static bool ResolveLabel(string objectIdentifier, string distinguishedName, string samAccountType, string[] objectClasses, int flags, out Label type) { type = Label.Base; diff --git a/src/CommonLib/RandomUtils.cs b/src/CommonLib/RandomUtils.cs new file mode 100644 index 000000000..e5a825dbc --- /dev/null +++ b/src/CommonLib/RandomUtils.cs @@ -0,0 +1,20 @@ +using System; +using System.Threading; + +namespace SharpHoundCommonLib; + +public static class RandomUtils { + private static readonly ThreadLocal Random = new(() => new Random()); + + public static long NextLong() => LongRandom(0, long.MaxValue); + + private static long LongRandom(long min, long max) { + var buf = new byte[8]; + Random.Value.NextBytes(buf); + var longRand = BitConverter.ToInt64(buf, 0); + return Math.Abs(longRand % (max - min)) + min; + } + + public static double Between(double minValue, double maxValue) => Random.Value.NextDouble() * (maxValue - minValue) + minValue; + public static long Between(long minValue, long maxValue) => LongRandom(minValue, maxValue); +} \ No newline at end of file diff --git a/test/unit/LdapConnectionPoolTest.cs b/test/unit/LdapConnectionPoolTest.cs index 5ac532c25..bede09089 100644 --- a/test/unit/LdapConnectionPoolTest.cs +++ b/test/unit/LdapConnectionPoolTest.cs @@ -7,39 +7,16 @@ public class LdapConnectionPoolTest { - private static void AddExclusionDomain(string identifier) { - var excludedDomainsField = typeof(LdapConnectionPool) - .GetField("_excludedDomains", BindingFlags.Static | BindingFlags.NonPublic); - - var excludedDomains = (ConcurrentHashSet)excludedDomainsField.GetValue(null); - - excludedDomains.Add(identifier); - } - [Fact] - public async Task LdapConnectionPool_ExcludedDomains_ShouldExitEarly() + public async Task LdapConnectionPool_Static_GetDomain_Add_To_ExcludedDomains_ShouldExitEarly() { var mockLogger = new Mock(); var ldapConfig = new LdapConfig(); var connectionPool = new ConnectionPoolManager(ldapConfig, mockLogger.Object); - AddExclusionDomain("excludedDomain.com"); var connectAttempt = await connectionPool.TestDomainConnection("excludedDomain.com", false); Assert.False(connectAttempt.Success); Assert.Contains("excluded for connection attempt", connectAttempt.Message); } - - [Fact] - public async Task LdapConnectionPool_ExcludedDomains_NonExcludedShouldntExit() - { - var mockLogger = new Mock(); - var ldapConfig = new LdapConfig(); - var connectionPool = new ConnectionPoolManager(ldapConfig, mockLogger.Object); - - AddExclusionDomain("excludedDomain.com"); - var connectAttempt = await connectionPool.TestDomainConnection("perfectlyValidDomain.com", false); - - Assert.DoesNotContain("excluded for connection attempt", connectAttempt.Message); - } } \ No newline at end of file diff --git a/test/unit/TimeoutTests.cs b/test/unit/TimeoutTests.cs index d5b06a204..ee4d1fd70 100644 --- a/test/unit/TimeoutTests.cs +++ b/test/unit/TimeoutTests.cs @@ -251,4 +251,21 @@ public async Task ExecuteWithTimeout_Task_T_ParentTokenCancel() { Assert.False(result.IsSuccess); Assert.Equal("Cancellation requested", result.Error); } + + [Theory] + [InlineData(0, 2, 30, 2, 6)] + [InlineData(5, 2, 200, 1, 192)] + [InlineData(5, 5, 500, 5, 480)] + [InlineData(0, 2, 1, 1, 1)] + [InlineData(5, 2, 1, 1, 1)] + [InlineData(5, 2, 2, 2, 2)] + [InlineData(5, 30, 30, 30, 30)] + public void DecorrelatedTimeSpan_BetweenExpected(int attempt, int baseDelayValue, int maxDelayValue, double expectedLowerBound, double expectedUpperBound) { + var baseDelay = TimeSpan.FromTicks(baseDelayValue); + var maxDelay = TimeSpan.FromTicks(maxDelayValue); + for (var trials = 0; trials < 500; trials++) { + var delay = SharpHoundCommonLib.Helpers.BackoffWithDecorrelatedJitter(attempt, baseDelay, maxDelay); + Assert.InRange(delay.Ticks, expectedLowerBound, expectedUpperBound); + } + } } \ No newline at end of file