diff --git a/src/mutex.c b/src/mutex.c index 1f8d41a..48b4589 100644 --- a/src/mutex.c +++ b/src/mutex.c @@ -14,7 +14,8 @@ // Await the value of an `atomic_int`. -static bool await_atomic_int(atomic_int *var, timestamp_us_t timeout, int expected, int new_value, memory_order order) { +static bool + await_swap_atomic_int(atomic_int *var, timestamp_us_t timeout, int expected, int new_value, memory_order order) { do { int old_value = expected; if (atomic_compare_exchange_weak_explicit(var, &old_value, new_value, order, memory_order_relaxed)) { @@ -27,10 +28,10 @@ static bool await_atomic_int(atomic_int *var, timestamp_us_t timeout, int expect } // Atomically check the value exceeds a threshold and subtract. -static bool thresh_subtrach_atomic_int(atomic_int *var, int threshold, int subtract, memory_order order) { +static bool thresh_sub_atomic_int(atomic_int *var, int threshold, int sub, memory_order order) { while (1) { int old_value = atomic_load(var); - int new_value = old_value - subtract; + int new_value = old_value - sub; if (old_value < threshold) { return false; } else if (atomic_compare_exchange_weak_explicit(var, &old_value, new_value, order, memory_order_relaxed)) { @@ -41,6 +42,36 @@ static bool thresh_subtrach_atomic_int(atomic_int *var, int threshold, int subtr } } +// Atomically check the value does not exceed a threshold and add. +static bool thresh_add_atomic_int(atomic_int *var, int threshold, int add, memory_order order) { + while (1) { + int old_value = atomic_load(var); + int new_value = old_value + add; + if (old_value >= threshold || new_value >= threshold) { + return false; + } else if (atomic_compare_exchange_weak_explicit(var, &old_value, new_value, order, memory_order_relaxed)) { + return true; + } else { + sched_yield(); + } + } +} + +// Atomically check the value doesn't equal either illegal values and subtract. +static bool unequal_sub_atomic_int(atomic_int *var, int unequal0, int unequal1, int sub, memory_order order) { + while (1) { + int old_value = atomic_load(var); + int new_value = old_value - sub; + if (old_value == unequal0 || old_value == unequal1) { + return false; + } else if (atomic_compare_exchange_weak_explicit(var, &old_value, new_value, order, memory_order_relaxed)) { + return true; + } else { + sched_yield(); + } + } +} + // Initialise a mutex for unshared use. @@ -88,7 +119,7 @@ bool mutex_acquire(badge_err_t *ec, mutex_t *mutex, timestamp_us_t timeout) { } timeout += time_us(); // Await the shared portion to reach 0 and then lock. - if (await_atomic_int(&mutex->shares, timeout, 0, EXCLUSIVE_MAGIC, memory_order_acquire)) { + if (await_swap_atomic_int(&mutex->shares, timeout, 0, EXCLUSIVE_MAGIC, memory_order_acquire)) { // If that succeeds, the mutex was acquired. badge_err_set_ok(ec); return true; @@ -106,10 +137,12 @@ bool mutex_release(badge_err_t *ec, mutex_t *mutex) { badge_err_set(ec, ELOC_UNKNOWN, ECAUSE_ILLEGAL); return false; } - if (thresh_subtrach_atomic_int(&mutex->shares, EXCLUSIVE_MAGIC, EXCLUSIVE_MAGIC, memory_order_release)) { + if (thresh_sub_atomic_int(&mutex->shares, EXCLUSIVE_MAGIC, EXCLUSIVE_MAGIC, memory_order_release)) { + // Successful release. badge_err_set_ok(ec); return true; } else { + // Mutex was not taken exclusively. badge_err_set(ec, ELOC_UNKNOWN, ECAUSE_ILLEGAL); return false; } @@ -128,16 +161,13 @@ bool mutex_acquire_shared(badge_err_t *ec, mutex_t *mutex, timestamp_us_t timeou } timeout += time_us(); // Take a share. - int val = atomic_fetch_add_explicit(&mutex->shares, 1, memory_order_acquire); - // Await the lock to be released. - if (val < EXCLUSIVE_MAGIC) { + if (thresh_add_atomic_int(&mutex->shares, EXCLUSIVE_MAGIC, 1, memory_order_acquire)) { // If that succeeds, the mutex was successfully acquired. badge_err_set_ok(ec); return true; } else { // If that fails, abort trying to lock. badge_err_set(ec, ELOC_UNKNOWN, ECAUSE_TIMEOUT); - atomic_fetch_sub_explicit(&mutex->shares, 1, memory_order_relaxed); return false; } } @@ -150,13 +180,13 @@ bool mutex_release_shared(badge_err_t *ec, mutex_t *mutex) { return false; } int old = atomic_fetch_sub_explicit(&mutex->shares, 1, memory_order_release); - if (old == 0 || old == EXCLUSIVE_MAGIC) { + if (unequal_sub_atomic_int(&mutex->shares, 0, EXCLUSIVE_MAGIC, 1, memory_order_release)) { // Prevent the counter from underflowing. badge_err_set(ec, ELOC_UNKNOWN, ECAUSE_ILLEGAL); - atomic_fetch_add_explicit(&mutex->shares, 1, memory_order_relaxed); return false; } else { // Successful release. + badge_err_set_ok(ec); return true; } }