Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ public class NonBlockingThreadPool<Environment: ConcurrencyPlatform>: ComputeThr
var cancelledStorage: AtomicUInt64
var blockedCountStorage: AtomicUInt64
var spinningState: AtomicUInt64
var condition: NonblockingCondition<Environment>
var waitingMutex: [Environment.ConditionMutex] // TODO: modify condition to add per-thread wakeup
var externalWaitingMutex: Environment.ConditionMutex
let condition: NonblockingCondition<Environment>
let waitingMutex: [Environment.ConditionMutex] // TODO: modify condition to add per-thread wakeup
let externalWaitingMutex: Environment.ConditionMutex
var threads: [Environment.Thread]

private let perThreadKey = Environment.ThreadLocalStorage.makeKey(
Expand Down Expand Up @@ -451,6 +451,11 @@ fileprivate final class PerThreadState<Environment: ConcurrencyPlatform> {
init(threadId: Int, pool: NonBlockingThreadPool<Environment>) {
self.threadId = threadId
self.pool = pool
self.totalThreadCount = pool.totalThreadCount
self.workerThreadCount = pool.totalThreadCount - pool.externalFastPathThreadCount
self.coprimes = pool.coprimes
self.queues = pool.queues
self.condition = pool.condition
self.rng = PCGRandomNumberGenerator(state: UInt64(threadId))
}
let threadId: Int
Expand All @@ -462,35 +467,41 @@ fileprivate final class PerThreadState<Environment: ConcurrencyPlatform> {
// possible to provide a safer API that doesn't leak by default without inducing an extra pointer
// dereference on critical paths. :-(

let totalThreadCount: Int
let workerThreadCount: Int
let coprimes: [Int]
let queues: [NonBlockingThreadPool<Environment>.Queue]
let condition: NonblockingCondition<Environment>

var rng: PCGRandomNumberGenerator

var isCancelled: Bool { pool.cancelled }

func steal() -> Task? {
let r = rng.next()
var selectedThreadId = Int(r.reduced(into: UInt64(pool.totalThreadCount)))
let step = pool.coprimes[Int(r.reduced(into: UInt64(pool.coprimes.count)))]
var selectedThreadId = Int(r.reduced(into: UInt64(totalThreadCount)))
let step = coprimes[Int(r.reduced(into: UInt64(coprimes.count)))]
assert(
step < pool.totalThreadCount, "step: \(step), pool threadcount: \(pool.totalThreadCount)")
step < totalThreadCount, "step: \(step), pool threadcount: \(totalThreadCount)")

for i in 0..<pool.totalThreadCount {
for i in 0..<totalThreadCount {
assert(
selectedThreadId < pool.totalThreadCount,
"\(selectedThreadId) is too big on iteration \(i); max: \(pool.totalThreadCount), step: \(step)"
selectedThreadId < totalThreadCount,
"\(selectedThreadId) is too big on iteration \(i); max: \(totalThreadCount), step: \(step)"
)
if let task = pool.queues[selectedThreadId].popBack() {
if let task = queues[selectedThreadId].popBack() {
return task
}
selectedThreadId += step
if selectedThreadId >= pool.totalThreadCount {
selectedThreadId -= pool.totalThreadCount
if selectedThreadId >= totalThreadCount {
selectedThreadId -= totalThreadCount
}
}
return nil
}

func spin() -> Task? {
let spinCount = pool.threads.count > 0 ? Constants.spinCount / pool.threads.count : 0
let spinCount = workerThreadCount > 0 ? Constants.spinCount / workerThreadCount : 0

if pool.shouldStartSpinning() {
// Call steal spin_count times; break if steal returns something.
Expand All @@ -511,12 +522,12 @@ fileprivate final class PerThreadState<Environment: ConcurrencyPlatform> {

func parkUntilWorkAvailable() -> Task? {
// Already did a best-effort emptiness check in steal, so prepare for blocking.
pool.condition.preWait()
condition.preWait()
// Now we do a reliable emptiness check.
if let nonEmptyQueueIndex = findNonEmptyQueueIndex() {
pool.condition.cancelWait()
condition.cancelWait()
// Steal from `nonEmptyQueueIndex`.
return pool.queues[nonEmptyQueueIndex].popBack()
return queues[nonEmptyQueueIndex].popBack()
}
let blockedCount = pool.blockedCountStorage.increment() + 1 // increment returns old value.
if blockedCount == pool.threads.count {
Expand All @@ -526,21 +537,21 @@ fileprivate final class PerThreadState<Environment: ConcurrencyPlatform> {
pool.condition.cancelWait()
return nil
}
pool.condition.commitWait(threadId)
condition.commitWait(threadId)
_ = pool.blockedCountStorage.decrement()
return nil
}

private func findNonEmptyQueueIndex() -> Int? {
let r = rng.next()
let increment =
pool.totalThreadCount == 1 ? 1 : pool.coprimes[Int(r.reduced(into: UInt64(pool.coprimes.count)))]
var threadIndex = Int(r.reduced(into: UInt64(pool.totalThreadCount)))
for _ in 0..<pool.totalThreadCount {
if !pool.queues[threadIndex].isEmpty { return threadIndex }
totalThreadCount == 1 ? 1 : coprimes[Int(r.reduced(into: UInt64(coprimes.count)))]
var threadIndex = Int(r.reduced(into: UInt64(totalThreadCount)))
for _ in 0..<totalThreadCount {
if !queues[threadIndex].isEmpty { return threadIndex }
threadIndex += increment
if threadIndex >= pool.totalThreadCount {
threadIndex -= pool.totalThreadCount
if threadIndex >= totalThreadCount {
threadIndex -= totalThreadCount
}
}
return nil
Expand Down