Skip to content

Commit

Permalink
convert PyUntypedArray to Bound API
Browse files Browse the repository at this point in the history
  • Loading branch information
Icxolu authored and adamreichold committed Mar 16, 2024
1 parent 744a3f3 commit 9517ed9
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 22 deletions.
16 changes: 15 additions & 1 deletion src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use crate::error::{
};
use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API};
use crate::slice_container::PySliceContainer;
use crate::untyped_array::PyUntypedArray;
use crate::untyped_array::{PyUntypedArray, PyUntypedArrayMethods};

/// A safe, statically-typed wrapper for NumPy's [`ndarray`][ndarray] class.
///
Expand Down Expand Up @@ -1480,6 +1480,20 @@ unsafe fn clone_elements<T: Element>(elems: &[T], data_ptr: &mut *mut T) {
}
}

/// Implementation of functionality for [`PyArray<T, D>`].
#[doc(alias = "PyArray")]
pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> {
/// Access an untyped representation of this array.
fn as_untyped(&self) -> &Bound<'py, PyUntypedArray>;
}

impl<'py, T, D> PyArrayMethods<'py, T, D> for Bound<'py, PyArray<T, D>> {
#[inline(always)]
fn as_untyped(&self) -> &Bound<'py, PyUntypedArray> {
unsafe { self.downcast_unchecked() }
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pub use nalgebra;

pub use crate::array::{
get_array_module, PyArray, PyArray0, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5,
PyArray6, PyArrayDyn,
PyArray6, PyArrayDyn, PyArrayMethods,
};
pub use crate::array_like::{
AllowTypeChange, PyArrayLike, PyArrayLike0, PyArrayLike1, PyArrayLike2, PyArrayLike3,
Expand All @@ -111,7 +111,7 @@ pub use crate::error::{BorrowError, FromVecError, NotContiguousError};
pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API};
pub use crate::strings::{PyFixedString, PyFixedUnicode};
pub use crate::sum_products::{dot, einsum, inner};
pub use crate::untyped_array::PyUntypedArray;
pub use crate::untyped_array::{PyUntypedArray, PyUntypedArrayMethods};

pub use ndarray::{array, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};

Expand Down
245 changes: 226 additions & 19 deletions src/untyped_array.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
//! Safe, untyped interface for NumPy's [N-dimensional arrays][ndarray]
//!
//! [ndarray]: https://numpy.org/doc/stable/reference/arrays.ndarray.html
use std::{os::raw::c_int, slice};
use std::slice;

use pyo3::{
ffi, pyobject_native_type_extract, pyobject_native_type_named, AsPyPointer, IntoPy, PyAny,
PyNativeType, PyObject, PyTypeInfo, Python,
ffi, pyobject_native_type_extract, pyobject_native_type_named, types::PyAnyMethods,
AsPyPointer, Bound, IntoPy, PyAny, PyNativeType, PyObject, PyTypeInfo, Python,
};

use crate::array::{PyArray, PyArrayMethods};
use crate::cold;
use crate::dtype::PyArrayDescr;
use crate::npyffi;
Expand Down Expand Up @@ -68,7 +69,7 @@ unsafe impl PyTypeInfo for PyUntypedArray {
unsafe { npyffi::PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
}

fn is_type_of(ob: &PyAny) -> bool {
fn is_type_of_bound(ob: &Bound<'_, PyAny>) -> bool {
unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) != 0 }
}
}
Expand All @@ -87,7 +88,7 @@ impl PyUntypedArray {
/// Returns a raw pointer to the underlying [`PyArrayObject`][npyffi::PyArrayObject].
#[inline]
pub fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject {
self.as_ptr() as _
self.as_borrowed().as_array_ptr()
}

/// Returns the `dtype` of the array.
Expand All @@ -109,16 +110,9 @@ impl PyUntypedArray {
///
/// [ndarray-dtype]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.dtype.html
/// [PyArray_DTYPE]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_DTYPE
#[inline]
pub fn dtype(&self) -> &PyArrayDescr {
unsafe {
let descr_ptr = (*self.as_array_ptr()).descr;
self.py().from_borrowed_ptr(descr_ptr as _)
}
}

#[inline(always)]
pub(crate) fn check_flags(&self, flags: c_int) -> bool {
unsafe { (*self.as_array_ptr()).flags & flags != 0 }
self.as_borrowed().dtype().into_gil_ref()
}

/// Returns `true` if the internal data of the array is contiguous,
Expand All @@ -142,18 +136,21 @@ impl PyUntypedArray {
/// assert!(!view.is_contiguous());
/// });
/// ```
#[inline]
pub fn is_contiguous(&self) -> bool {
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS)
self.as_borrowed().is_contiguous()
}

/// Returns `true` if the internal data of the array is Fortran-style/column-major contiguous.
#[inline]
pub fn is_fortran_contiguous(&self) -> bool {
self.check_flags(npyffi::NPY_ARRAY_F_CONTIGUOUS)
self.as_borrowed().is_fortran_contiguous()
}

/// Returns `true` if the internal data of the array is C-style/row-major contiguous.
#[inline]
pub fn is_c_contiguous(&self) -> bool {
self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS)
self.as_borrowed().is_c_contiguous()
}

/// Returns the number of dimensions of the array.
Expand All @@ -177,7 +174,7 @@ impl PyUntypedArray {
/// [PyArray_NDIM]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_NDIM
#[inline]
pub fn ndim(&self) -> usize {
unsafe { (*self.as_array_ptr()).nd as usize }
self.as_borrowed().ndim()
}

/// Returns a slice indicating how many bytes to advance when iterating along each axis.
Expand Down Expand Up @@ -246,12 +243,222 @@ impl PyUntypedArray {
}

/// Calculates the total number of elements in the array.
#[inline]
pub fn len(&self) -> usize {
self.shape().iter().product()
self.as_borrowed().len()
}

/// Returns `true` if the there are no elements in the array.
#[inline]
pub fn is_empty(&self) -> bool {
self.as_borrowed().is_empty()
}
}

/// Implementation of functionality for [`PyUntypedArray`].
#[doc(alias = "PyUntypedArray")]
pub trait PyUntypedArrayMethods<'py>: sealed::Sealed {
/// Returns a raw pointer to the underlying [`PyArrayObject`][npyffi::PyArrayObject].
fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject;

/// Returns the `dtype` of the array.
///
/// See also [`ndarray.dtype`][ndarray-dtype] and [`PyArray_DTYPE`][PyArray_DTYPE].
///
/// # Example
///
/// ```
/// use numpy::{dtype, PyArray};
/// use pyo3::Python;
///
/// Python::with_gil(|py| {
/// let array = PyArray::from_vec(py, vec![1_i32, 2, 3]);
///
/// assert!(array.dtype().is_equiv_to(dtype::<i32>(py)));
/// });
/// ```
///
/// [ndarray-dtype]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.dtype.html
/// [PyArray_DTYPE]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_DTYPE
fn dtype(&self) -> Bound<'py, PyArrayDescr>;

/// Returns `true` if the internal data of the array is contiguous,
/// indepedently of whether C-style/row-major or Fortran-style/column-major.
///
/// # Example
///
/// ```
/// use numpy::PyArray1;
/// use pyo3::{types::IntoPyDict, Python};
///
/// Python::with_gil(|py| {
/// let array = PyArray1::arange(py, 0, 10, 1);
/// assert!(array.is_contiguous());
///
/// let view = py
/// .eval("array[::2]", None, Some([("array", array)].into_py_dict(py)))
/// .unwrap()
/// .downcast::<PyArray1<i32>>()
/// .unwrap();
/// assert!(!view.is_contiguous());
/// });
/// ```
fn is_contiguous(&self) -> bool {
unsafe {
check_flags(
&*self.as_array_ptr(),
npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS,
)
}
}

/// Returns `true` if the internal data of the array is Fortran-style/column-major contiguous.
fn is_fortran_contiguous(&self) -> bool {
unsafe { check_flags(&*self.as_array_ptr(), npyffi::NPY_ARRAY_F_CONTIGUOUS) }
}

/// Returns `true` if the internal data of the array is C-style/row-major contiguous.
fn is_c_contiguous(&self) -> bool {
unsafe { check_flags(&*self.as_array_ptr(), npyffi::NPY_ARRAY_C_CONTIGUOUS) }
}

/// Returns the number of dimensions of the array.
///
/// See also [`ndarray.ndim`][ndarray-ndim] and [`PyArray_NDIM`][PyArray_NDIM].
///
/// # Example
///
/// ```
/// use numpy::PyArray3;
/// use pyo3::Python;
///
/// Python::with_gil(|py| {
/// let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
///
/// assert_eq!(arr.ndim(), 3);
/// });
/// ```
///
/// [ndarray-ndim]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.ndim.html
/// [PyArray_NDIM]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_NDIM
#[inline]
fn ndim(&self) -> usize {
unsafe { (*self.as_array_ptr()).nd as usize }
}

/// Returns a slice indicating how many bytes to advance when iterating along each axis.
///
/// See also [`ndarray.strides`][ndarray-strides] and [`PyArray_STRIDES`][PyArray_STRIDES].
///
/// # Example
///
/// ```
/// use numpy::PyArray3;
/// use pyo3::Python;
///
/// Python::with_gil(|py| {
/// let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
///
/// assert_eq!(arr.strides(), &[240, 48, 8]);
/// });
/// ```
/// [ndarray-strides]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
/// [PyArray_STRIDES]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_STRIDES
#[inline]
fn strides(&self) -> &[isize] {
let n = self.ndim();
if n == 0 {
cold();
return &[];
}
let ptr = self.as_array_ptr();
unsafe {
let p = (*ptr).strides;
slice::from_raw_parts(p, n)
}
}

/// Returns a slice which contains dimmensions of the array.
///
/// See also [`ndarray.shape`][ndaray-shape] and [`PyArray_DIMS`][PyArray_DIMS].
///
/// # Example
///
/// ```
/// use numpy::PyArray3;
/// use pyo3::Python;
///
/// Python::with_gil(|py| {
/// let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
///
/// assert_eq!(arr.shape(), &[4, 5, 6]);
/// });
/// ```
///
/// [ndarray-shape]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.shape.html
/// [PyArray_DIMS]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_DIMS
#[inline]
fn shape(&self) -> &[usize] {
let n = self.ndim();
if n == 0 {
cold();
return &[];
}
let ptr = self.as_array_ptr();
unsafe {
let p = (*ptr).dimensions as *mut usize;
slice::from_raw_parts(p, n)
}
}

/// Calculates the total number of elements in the array.
fn len(&self) -> usize {
self.shape().iter().product()
}

/// Returns `true` if the there are no elements in the array.
fn is_empty(&self) -> bool {
self.shape().iter().any(|dim| *dim == 0)
}
}

fn check_flags(obj: &npyffi::PyArrayObject, flags: i32) -> bool {
obj.flags & flags != 0
}

impl<'py> PyUntypedArrayMethods<'py> for Bound<'py, PyUntypedArray> {
#[inline]
fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject {
self.as_ptr().cast()
}

fn dtype(&self) -> Bound<'py, PyArrayDescr> {
unsafe {
let descr_ptr = (*self.as_array_ptr()).descr;
Bound::from_borrowed_ptr(self.py(), descr_ptr.cast()).downcast_into_unchecked()
}
}
}

// We won't be able to provide a `Deref` impl from `Bound<'_, PyArray<T, D>>` to
// `Bound<'_, PyUntypedArray>`, so this seems to be the next best thing to do
impl<'py, T, D> PyUntypedArrayMethods<'py> for Bound<'py, PyArray<T, D>> {
#[inline]
fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject {
self.as_untyped().as_array_ptr()
}

#[inline]
fn dtype(&self) -> Bound<'py, PyArrayDescr> {
self.as_untyped().dtype()
}
}

mod sealed {
use super::{PyArray, PyUntypedArray};

pub trait Sealed {}

impl Sealed for pyo3::Bound<'_, PyUntypedArray> {}
impl<T, D> Sealed for pyo3::Bound<'_, PyArray<T, D>> {}
}

0 comments on commit 9517ed9

Please sign in to comment.