diff --git a/src/array.rs b/src/array.rs index a57567996..e000c9201 100644 --- a/src/array.rs +++ b/src/array.rs @@ -33,7 +33,7 @@ use crate::error::{ }; use crate::npyffi::{self, npy_intp, NPY_ORDER, PY_ARRAY_API}; use crate::slice_container::PySliceContainer; -use crate::untyped_array::PyUntypedArray; +use crate::untyped_array::{PyUntypedArray, PyUntypedArrayMethods}; /// A safe, statically-typed wrapper for NumPy's [`ndarray`][ndarray] class. /// @@ -1480,6 +1480,20 @@ unsafe fn clone_elements(elems: &[T], data_ptr: &mut *mut T) { } } +/// Implementation of functionality for [`PyArray`]. +#[doc(alias = "PyArray")] +pub trait PyArrayMethods<'py, T, D>: PyUntypedArrayMethods<'py> { + /// Access an untyped representation of this array. + fn as_untyped(&self) -> &Bound<'py, PyUntypedArray>; +} + +impl<'py, T, D> PyArrayMethods<'py, T, D> for Bound<'py, PyArray> { + #[inline(always)] + fn as_untyped(&self) -> &Bound<'py, PyUntypedArray> { + unsafe { self.downcast_unchecked() } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/lib.rs b/src/lib.rs index 2b11bc3d5..8e538366e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,7 +93,7 @@ pub use nalgebra; pub use crate::array::{ get_array_module, PyArray, PyArray0, PyArray1, PyArray2, PyArray3, PyArray4, PyArray5, - PyArray6, PyArrayDyn, + PyArray6, PyArrayDyn, PyArrayMethods, }; pub use crate::array_like::{ AllowTypeChange, PyArrayLike, PyArrayLike0, PyArrayLike1, PyArrayLike2, PyArrayLike3, @@ -111,7 +111,7 @@ pub use crate::error::{BorrowError, FromVecError, NotContiguousError}; pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API}; pub use crate::strings::{PyFixedString, PyFixedUnicode}; pub use crate::sum_products::{dot, einsum, inner}; -pub use crate::untyped_array::PyUntypedArray; +pub use crate::untyped_array::{PyUntypedArray, PyUntypedArrayMethods}; pub use ndarray::{array, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; diff --git a/src/untyped_array.rs b/src/untyped_array.rs index 089afdd10..fa5ae08ee 100644 --- a/src/untyped_array.rs +++ b/src/untyped_array.rs @@ -1,13 +1,14 @@ //! Safe, untyped interface for NumPy's [N-dimensional arrays][ndarray] //! //! [ndarray]: https://numpy.org/doc/stable/reference/arrays.ndarray.html -use std::{os::raw::c_int, slice}; +use std::slice; use pyo3::{ - ffi, pyobject_native_type_extract, pyobject_native_type_named, AsPyPointer, IntoPy, PyAny, - PyNativeType, PyObject, PyTypeInfo, Python, + ffi, pyobject_native_type_extract, pyobject_native_type_named, types::PyAnyMethods, + AsPyPointer, Bound, IntoPy, PyAny, PyNativeType, PyObject, PyTypeInfo, Python, }; +use crate::array::{PyArray, PyArrayMethods}; use crate::cold; use crate::dtype::PyArrayDescr; use crate::npyffi; @@ -68,7 +69,7 @@ unsafe impl PyTypeInfo for PyUntypedArray { unsafe { npyffi::PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type) } } - fn is_type_of(ob: &PyAny) -> bool { + fn is_type_of_bound(ob: &Bound<'_, PyAny>) -> bool { unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) != 0 } } } @@ -87,7 +88,7 @@ impl PyUntypedArray { /// Returns a raw pointer to the underlying [`PyArrayObject`][npyffi::PyArrayObject]. #[inline] pub fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject { - self.as_ptr() as _ + self.as_borrowed().as_array_ptr() } /// Returns the `dtype` of the array. @@ -109,16 +110,9 @@ impl PyUntypedArray { /// /// [ndarray-dtype]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.dtype.html /// [PyArray_DTYPE]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_DTYPE + #[inline] pub fn dtype(&self) -> &PyArrayDescr { - unsafe { - let descr_ptr = (*self.as_array_ptr()).descr; - self.py().from_borrowed_ptr(descr_ptr as _) - } - } - - #[inline(always)] - pub(crate) fn check_flags(&self, flags: c_int) -> bool { - unsafe { (*self.as_array_ptr()).flags & flags != 0 } + self.as_borrowed().dtype().into_gil_ref() } /// Returns `true` if the internal data of the array is contiguous, @@ -142,18 +136,21 @@ impl PyUntypedArray { /// assert!(!view.is_contiguous()); /// }); /// ``` + #[inline] pub fn is_contiguous(&self) -> bool { - self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS) + self.as_borrowed().is_contiguous() } /// Returns `true` if the internal data of the array is Fortran-style/column-major contiguous. + #[inline] pub fn is_fortran_contiguous(&self) -> bool { - self.check_flags(npyffi::NPY_ARRAY_F_CONTIGUOUS) + self.as_borrowed().is_fortran_contiguous() } /// Returns `true` if the internal data of the array is C-style/row-major contiguous. + #[inline] pub fn is_c_contiguous(&self) -> bool { - self.check_flags(npyffi::NPY_ARRAY_C_CONTIGUOUS) + self.as_borrowed().is_c_contiguous() } /// Returns the number of dimensions of the array. @@ -177,7 +174,7 @@ impl PyUntypedArray { /// [PyArray_NDIM]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_NDIM #[inline] pub fn ndim(&self) -> usize { - unsafe { (*self.as_array_ptr()).nd as usize } + self.as_borrowed().ndim() } /// Returns a slice indicating how many bytes to advance when iterating along each axis. @@ -246,12 +243,222 @@ impl PyUntypedArray { } /// Calculates the total number of elements in the array. + #[inline] pub fn len(&self) -> usize { - self.shape().iter().product() + self.as_borrowed().len() } /// Returns `true` if the there are no elements in the array. + #[inline] pub fn is_empty(&self) -> bool { + self.as_borrowed().is_empty() + } +} + +/// Implementation of functionality for [`PyUntypedArray`]. +#[doc(alias = "PyUntypedArray")] +pub trait PyUntypedArrayMethods<'py>: sealed::Sealed { + /// Returns a raw pointer to the underlying [`PyArrayObject`][npyffi::PyArrayObject]. + fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject; + + /// Returns the `dtype` of the array. + /// + /// See also [`ndarray.dtype`][ndarray-dtype] and [`PyArray_DTYPE`][PyArray_DTYPE]. + /// + /// # Example + /// + /// ``` + /// use numpy::{dtype, PyArray}; + /// use pyo3::Python; + /// + /// Python::with_gil(|py| { + /// let array = PyArray::from_vec(py, vec![1_i32, 2, 3]); + /// + /// assert!(array.dtype().is_equiv_to(dtype::(py))); + /// }); + /// ``` + /// + /// [ndarray-dtype]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.dtype.html + /// [PyArray_DTYPE]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_DTYPE + fn dtype(&self) -> Bound<'py, PyArrayDescr>; + + /// Returns `true` if the internal data of the array is contiguous, + /// indepedently of whether C-style/row-major or Fortran-style/column-major. + /// + /// # Example + /// + /// ``` + /// use numpy::PyArray1; + /// use pyo3::{types::IntoPyDict, Python}; + /// + /// Python::with_gil(|py| { + /// let array = PyArray1::arange(py, 0, 10, 1); + /// assert!(array.is_contiguous()); + /// + /// let view = py + /// .eval("array[::2]", None, Some([("array", array)].into_py_dict(py))) + /// .unwrap() + /// .downcast::>() + /// .unwrap(); + /// assert!(!view.is_contiguous()); + /// }); + /// ``` + fn is_contiguous(&self) -> bool { + unsafe { + check_flags( + &*self.as_array_ptr(), + npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS, + ) + } + } + + /// Returns `true` if the internal data of the array is Fortran-style/column-major contiguous. + fn is_fortran_contiguous(&self) -> bool { + unsafe { check_flags(&*self.as_array_ptr(), npyffi::NPY_ARRAY_F_CONTIGUOUS) } + } + + /// Returns `true` if the internal data of the array is C-style/row-major contiguous. + fn is_c_contiguous(&self) -> bool { + unsafe { check_flags(&*self.as_array_ptr(), npyffi::NPY_ARRAY_C_CONTIGUOUS) } + } + + /// Returns the number of dimensions of the array. + /// + /// See also [`ndarray.ndim`][ndarray-ndim] and [`PyArray_NDIM`][PyArray_NDIM]. + /// + /// # Example + /// + /// ``` + /// use numpy::PyArray3; + /// use pyo3::Python; + /// + /// Python::with_gil(|py| { + /// let arr = PyArray3::::zeros(py, [4, 5, 6], false); + /// + /// assert_eq!(arr.ndim(), 3); + /// }); + /// ``` + /// + /// [ndarray-ndim]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.ndim.html + /// [PyArray_NDIM]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_NDIM + #[inline] + fn ndim(&self) -> usize { + unsafe { (*self.as_array_ptr()).nd as usize } + } + + /// Returns a slice indicating how many bytes to advance when iterating along each axis. + /// + /// See also [`ndarray.strides`][ndarray-strides] and [`PyArray_STRIDES`][PyArray_STRIDES]. + /// + /// # Example + /// + /// ``` + /// use numpy::PyArray3; + /// use pyo3::Python; + /// + /// Python::with_gil(|py| { + /// let arr = PyArray3::::zeros(py, [4, 5, 6], false); + /// + /// assert_eq!(arr.strides(), &[240, 48, 8]); + /// }); + /// ``` + /// [ndarray-strides]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html + /// [PyArray_STRIDES]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_STRIDES + #[inline] + fn strides(&self) -> &[isize] { + let n = self.ndim(); + if n == 0 { + cold(); + return &[]; + } + let ptr = self.as_array_ptr(); + unsafe { + let p = (*ptr).strides; + slice::from_raw_parts(p, n) + } + } + + /// Returns a slice which contains dimmensions of the array. + /// + /// See also [`ndarray.shape`][ndaray-shape] and [`PyArray_DIMS`][PyArray_DIMS]. + /// + /// # Example + /// + /// ``` + /// use numpy::PyArray3; + /// use pyo3::Python; + /// + /// Python::with_gil(|py| { + /// let arr = PyArray3::::zeros(py, [4, 5, 6], false); + /// + /// assert_eq!(arr.shape(), &[4, 5, 6]); + /// }); + /// ``` + /// + /// [ndarray-shape]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.shape.html + /// [PyArray_DIMS]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_DIMS + #[inline] + fn shape(&self) -> &[usize] { + let n = self.ndim(); + if n == 0 { + cold(); + return &[]; + } + let ptr = self.as_array_ptr(); + unsafe { + let p = (*ptr).dimensions as *mut usize; + slice::from_raw_parts(p, n) + } + } + + /// Calculates the total number of elements in the array. + fn len(&self) -> usize { + self.shape().iter().product() + } + + /// Returns `true` if the there are no elements in the array. + fn is_empty(&self) -> bool { self.shape().iter().any(|dim| *dim == 0) } } + +fn check_flags(obj: &npyffi::PyArrayObject, flags: i32) -> bool { + obj.flags & flags != 0 +} + +impl<'py> PyUntypedArrayMethods<'py> for Bound<'py, PyUntypedArray> { + #[inline] + fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject { + self.as_ptr().cast() + } + + fn dtype(&self) -> Bound<'py, PyArrayDescr> { + unsafe { + let descr_ptr = (*self.as_array_ptr()).descr; + Bound::from_borrowed_ptr(self.py(), descr_ptr.cast()).downcast_into_unchecked() + } + } +} + +// We won't be able to provide a `Deref` impl from `Bound<'_, PyArray>` to +// `Bound<'_, PyUntypedArray>`, so this seems to be the next best thing to do +impl<'py, T, D> PyUntypedArrayMethods<'py> for Bound<'py, PyArray> { + #[inline] + fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject { + self.as_untyped().as_array_ptr() + } + + #[inline] + fn dtype(&self) -> Bound<'py, PyArrayDescr> { + self.as_untyped().dtype() + } +} + +mod sealed { + use super::{PyArray, PyUntypedArray}; + + pub trait Sealed {} + + impl Sealed for pyo3::Bound<'_, PyUntypedArray> {} + impl Sealed for pyo3::Bound<'_, PyArray> {} +}