Skip to content

Commit

Permalink
fix(rust): Incorrect atomic ordering in Connector (#21341)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Feb 19, 2025
1 parent f2e4ba1 commit d6bb315
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions crates/polars-stream/src/async_primitives/connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ pub fn connector<T>() -> (Sender<T>, Receiver<T>) {
/*
For UnsafeCell safety, a sender may only set the FULL_BIT (giving exclusive
access to value to the receiver), and a receiver may only unset the FULL_BIT
(giving exclusive access back to the sender).
(giving exclusive access back to the sender). Setting/clearing the FULL_BIT
must be done with a Release ordering, and before reading/writing the value
the FULL_BIT must be checked with an Acquire ordering.
The exception is when the closed bit is set, at that point the unclosed
end has full exclusive access.
Expand Down Expand Up @@ -66,14 +68,14 @@ pub enum RecvError {
impl<T> Connector<T> {
unsafe fn poll_send(&self, value: &mut Option<T>, waker: &Waker) -> Poll<Result<(), T>> {
if let Some(v) = value.take() {
let mut state = self.state.load(Ordering::Relaxed);
let mut state = self.state.load(Ordering::Acquire);
if state & FULL_BIT == FULL_BIT {
self.send_waker.register(waker);
let (Ok(s) | Err(s)) = self.state.compare_exchange(
state,
state | WAITING_BIT,
Ordering::Release,
Ordering::Relaxed,
Ordering::Acquire, // Receiver updated, re-acquire.
);
state = s;
}
Expand Down Expand Up @@ -101,11 +103,13 @@ impl<T> Connector<T> {

unsafe {
self.value.get().write(MaybeUninit::new(value));
let state = self.state.swap(FULL_BIT, Ordering::AcqRel);
let state = self.state.swap(FULL_BIT, Ordering::Release);
if state & WAITING_BIT == WAITING_BIT {
self.recv_waker.wake();
}
if state & CLOSED_BIT == CLOSED_BIT {
// SAFETY: no synchronization needed, we are the only one left.
// Restore the closed bit we just overwrote.
self.state.store(CLOSED_BIT, Ordering::Relaxed);
return Err(SendError::Closed(self.value.get().read().assume_init()));
}
Expand All @@ -121,8 +125,8 @@ impl<T> Connector<T> {
let (Ok(s) | Err(s)) = self.state.compare_exchange(
state,
state | WAITING_BIT,
Ordering::Release,
Ordering::Acquire,
Ordering::Relaxed,
Ordering::Acquire, // Sender updated, re-acquire.
);
state = s;
}
Expand All @@ -138,11 +142,12 @@ impl<T> Connector<T> {
if state & FULL_BIT == FULL_BIT {
unsafe {
let ret = self.value.get().read().assume_init();
let state = self.state.swap(0, Ordering::Acquire);
let state = self.state.swap(0, Ordering::Release);
if state & WAITING_BIT == WAITING_BIT {
self.send_waker.wake();
}
if state & CLOSED_BIT == CLOSED_BIT {
// Restore the closed bit we just overwrote.
self.state.store(CLOSED_BIT, Ordering::Relaxed);
}
return Ok(ret);
Expand All @@ -159,7 +164,7 @@ impl<T> Connector<T> {
}

unsafe fn try_send(&self, value: T) -> Result<(), SendError<T>> {
self.try_send_impl(value, self.state.load(Ordering::Relaxed))
self.try_send_impl(value, self.state.load(Ordering::Acquire))
}

unsafe fn try_recv(&self) -> Result<T, RecvError> {
Expand All @@ -176,8 +181,8 @@ impl<T> Connector<T> {
/// # Safety
/// You may not access this connector anymore as a receiver after this call.
unsafe fn close_recv(&self) {
self.state.fetch_or(CLOSED_BIT, Ordering::Relaxed);
drop(self.try_recv());
let state = self.state.fetch_or(CLOSED_BIT, Ordering::Acquire);
drop(self.try_recv_impl(state));
self.send_waker.wake();
}
}
Expand Down

0 comments on commit d6bb315

Please sign in to comment.