Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: cleanup internal nparray_to_* methods #80

Merged
merged 8 commits into from
Jan 17, 2025
46 changes: 24 additions & 22 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,24 @@ impl CodecPipelineImpl {
}

fn py_untyped_array_to_array_object<'a>(
value: &Bound<'a, PyUntypedArray>,
value: &'a Bound<'_, PyUntypedArray>,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oof good catch. this is why unsafe is scary, this was extremely wrong.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious, is there a tl;dr for why this was wrong?

Copy link
Collaborator

@flying-sheep flying-sheep Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! I’ll rename the lifetimes for easier understanding.

First, Bound is a GIL binding, that means it’s a kind of reference1 that lives as long as a certain part of our code holds Python’s GIL (Global Interpreter Lock). This guarantees that nothing in Python land tries to mutate that object while we’re accessing it at the same time. So:

  • &'x Bound<'py, PyUntypedArray> means “a reference with lifetime 'x to a GIL binding with lifetime 'py to a PyArrayObject”. So the reference needs to live (most likely) shorter than the GIL binding2 or maximally equally long.
  • &'y PyArrayObject means “a reference with lifetime 'y to a PyArrayObject”.

The previous version returned a reference &'y PyArrayObject with 'y: 'py, which means that the returned reference is valid as long as the GIL binding is held. In reality the returned reference is derived from the input reference &'x. Remember that this &'x reference is probably shorter-lived than the GIL binding. So before this fix, nothing stopped us from dropping the &'x reference (because as far as rustc is concerned, nothing derives from it), creating a new, mutable reference to the same object, and mutating the object through that, while other parts of our code are still allowed to read that object through the &'y reference. Something like:

let mut value: Bound<'py, PyUntypedArray> = ...;

let array: &'py PyUntypedArray = {
    let readonly_ref = &value;
    py_untyped_array_to_array_object(readonly_ref)
}; // readonly_ref is dropped here

let mut_ref = value.borrow_mut();
thread::spawn(move || {
    write_into(mut_ref);  // race condition
}
thread::spawn(move || {
    println!("{?:array}");  // race condition
}

Footnotes

  1. I’ll only call regular Rust references “reference” after this to avoid confusion

  2. because before you reference something, you need to create it, and the thing can only be destroyed after you dropped your reference

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just wanted to say thanks for the detailed explanation! It still makes my head spin, which is a good reason to avoid unsafe in my own code 😅

) -> &'a PyArrayObject {
let array_object_ptr: *mut PyArrayObject = value.as_array_ptr();
unsafe {
// SAFETY: array_object_ptr cannot be null
&*array_object_ptr
}
let array_object: &'a PyArrayObject = unsafe {
// SAFETY: array_object_ptr cannot be null and the array object pointed to by array_object_ptr is valid for 'a
array_object_ptr
.as_ref()
.expect("pointer is convertible to a reference")
flying-sheep marked this conversation as resolved.
Show resolved Hide resolved
};
array_object
}

fn nparray_to_slice<'a>(value: &'a Bound<'_, PyUntypedArray>) -> &'a [u8] {
fn nparray_to_slice<'a>(value: &'a Bound<'_, PyUntypedArray>) -> Result<&'a [u8], PyErr> {
if !value.is_c_contiguous() {
return Err(PyErr::new::<PyValueError, _>(
"input array must be a C contiguous array".to_string(),
));
}
flying-sheep marked this conversation as resolved.
Show resolved Hide resolved
let array_object: &PyArrayObject = Self::py_untyped_array_to_array_object(value);
let array_data = array_object.data.cast::<u8>();
let array_len = value.len() * Self::pyarray_itemsize(value);
Expand All @@ -173,12 +181,17 @@ impl CodecPipelineImpl {
debug_assert!(!array_data.is_null());
std::slice::from_raw_parts(array_data, array_len)
};
slice
Ok(slice)
}

fn nparray_to_unsafe_cell_slice<'a>(
value: &'a Bound<'_, PyUntypedArray>,
) -> UnsafeCellSlice<'a, u8> {
) -> Result<UnsafeCellSlice<'a, u8>, PyErr> {
if !value.is_c_contiguous() {
return Err(PyErr::new::<PyValueError, _>(
"input array must be a C contiguous array".to_string(),
));
}
let array_object: &PyArrayObject = Self::py_untyped_array_to_array_object(value);
let array_data = array_object.data.cast::<u8>();
let array_len = value.len() * Self::pyarray_itemsize(value);
Expand All @@ -187,7 +200,7 @@ impl CodecPipelineImpl {
debug_assert!(!array_data.is_null());
std::slice::from_raw_parts_mut(array_data, array_len)
};
UnsafeCellSlice::new(output)
Ok(UnsafeCellSlice::new(output))
}
}

Expand Down Expand Up @@ -248,12 +261,7 @@ impl CodecPipelineImpl {
value: &Bound<'_, PyUntypedArray>,
) -> PyResult<()> {
// Get input array
if !value.is_c_contiguous() {
return Err(PyErr::new::<PyValueError, _>(
"input array must be a C contiguous array".to_string(),
));
}
let output = Self::nparray_to_unsafe_cell_slice(value);
let output = Self::nparray_to_unsafe_cell_slice(value)?;
let output_shape: Vec<u64> = value.shape_zarr()?;

// Adjust the concurrency based on the codec chain and the first chunk description
Expand Down Expand Up @@ -397,13 +405,7 @@ impl CodecPipelineImpl {
}

// Get input array
if !value.is_c_contiguous() {
return Err(PyErr::new::<PyValueError, _>(
"input array must be a C contiguous array".to_string(),
));
}

let input_slice = Self::nparray_to_slice(value);
let input_slice = Self::nparray_to_slice(value)?;
let input = if value.ndim() > 0 {
InputValue::Array(ArrayBytes::new_flen(Cow::Borrowed(input_slice)))
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ fn test_nparray_to_unsafe_cell_slice_empty() -> PyResult<()> {
.call0()?
.extract()?;

let slice = CodecPipelineImpl::nparray_to_unsafe_cell_slice(&arr);
let slice = CodecPipelineImpl::nparray_to_unsafe_cell_slice(&arr)?;
assert!(slice.is_empty());
Ok(())
})
Expand Down
Loading