diff --git a/Sources/PenguinParallel/NonblockingThreadPool/NonBlockingThreadPool.swift b/Sources/PenguinParallel/NonblockingThreadPool/NonBlockingThreadPool.swift index ca672fe5..446a319a 100644 --- a/Sources/PenguinParallel/NonblockingThreadPool/NonBlockingThreadPool.swift +++ b/Sources/PenguinParallel/NonblockingThreadPool/NonBlockingThreadPool.swift @@ -69,9 +69,9 @@ public class NonBlockingThreadPool: ComputeThr var cancelledStorage: AtomicUInt64 var blockedCountStorage: AtomicUInt64 var spinningState: AtomicUInt64 - var condition: NonblockingCondition - var waitingMutex: [Environment.ConditionMutex] // TODO: modify condition to add per-thread wakeup - var externalWaitingMutex: Environment.ConditionMutex + let condition: NonblockingCondition + let waitingMutex: [Environment.ConditionMutex] // TODO: modify condition to add per-thread wakeup + let externalWaitingMutex: Environment.ConditionMutex var threads: [Environment.Thread] private let perThreadKey = Environment.ThreadLocalStorage.makeKey( @@ -451,6 +451,11 @@ fileprivate final class PerThreadState { init(threadId: Int, pool: NonBlockingThreadPool) { self.threadId = threadId self.pool = pool + self.totalThreadCount = pool.totalThreadCount + self.workerThreadCount = pool.totalThreadCount - pool.externalFastPathThreadCount + self.coprimes = pool.coprimes + self.queues = pool.queues + self.condition = pool.condition self.rng = PCGRandomNumberGenerator(state: UInt64(threadId)) } let threadId: Int @@ -462,35 +467,41 @@ fileprivate final class PerThreadState { // possible to provide a safer API that doesn't leak by default without inducing an extra pointer // dereference on critical paths. :-( + let totalThreadCount: Int + let workerThreadCount: Int + let coprimes: [Int] + let queues: [NonBlockingThreadPool.Queue] + let condition: NonblockingCondition + var rng: PCGRandomNumberGenerator var isCancelled: Bool { pool.cancelled } func steal() -> Task? { let r = rng.next() - var selectedThreadId = Int(r.reduced(into: UInt64(pool.totalThreadCount))) - let step = pool.coprimes[Int(r.reduced(into: UInt64(pool.coprimes.count)))] + var selectedThreadId = Int(r.reduced(into: UInt64(totalThreadCount))) + let step = coprimes[Int(r.reduced(into: UInt64(coprimes.count)))] assert( - step < pool.totalThreadCount, "step: \(step), pool threadcount: \(pool.totalThreadCount)") + step < totalThreadCount, "step: \(step), pool threadcount: \(totalThreadCount)") - for i in 0..= pool.totalThreadCount { - selectedThreadId -= pool.totalThreadCount + if selectedThreadId >= totalThreadCount { + selectedThreadId -= totalThreadCount } } return nil } func spin() -> Task? { - let spinCount = pool.threads.count > 0 ? Constants.spinCount / pool.threads.count : 0 + let spinCount = workerThreadCount > 0 ? Constants.spinCount / workerThreadCount : 0 if pool.shouldStartSpinning() { // Call steal spin_count times; break if steal returns something. @@ -511,12 +522,12 @@ fileprivate final class PerThreadState { func parkUntilWorkAvailable() -> Task? { // Already did a best-effort emptiness check in steal, so prepare for blocking. - pool.condition.preWait() + condition.preWait() // Now we do a reliable emptiness check. if let nonEmptyQueueIndex = findNonEmptyQueueIndex() { - pool.condition.cancelWait() + condition.cancelWait() // Steal from `nonEmptyQueueIndex`. - return pool.queues[nonEmptyQueueIndex].popBack() + return queues[nonEmptyQueueIndex].popBack() } let blockedCount = pool.blockedCountStorage.increment() + 1 // increment returns old value. if blockedCount == pool.threads.count { @@ -526,7 +537,7 @@ fileprivate final class PerThreadState { pool.condition.cancelWait() return nil } - pool.condition.commitWait(threadId) + condition.commitWait(threadId) _ = pool.blockedCountStorage.decrement() return nil } @@ -534,13 +545,13 @@ fileprivate final class PerThreadState { private func findNonEmptyQueueIndex() -> Int? { let r = rng.next() let increment = - pool.totalThreadCount == 1 ? 1 : pool.coprimes[Int(r.reduced(into: UInt64(pool.coprimes.count)))] - var threadIndex = Int(r.reduced(into: UInt64(pool.totalThreadCount))) - for _ in 0..= pool.totalThreadCount { - threadIndex -= pool.totalThreadCount + if threadIndex >= totalThreadCount { + threadIndex -= totalThreadCount } } return nil