Skip to content

Commit

Permalink
convert PyArrayDescr to Bound API
Browse files Browse the repository at this point in the history
  • Loading branch information
Icxolu authored and adamreichold committed Mar 23, 2024
1 parent 9517ed9 commit fad4e18
Show file tree
Hide file tree
Showing 8 changed files with 421 additions and 125 deletions.
20 changes: 10 additions & 10 deletions src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use pyo3::{
use crate::borrow::{PyReadonlyArray, PyReadwriteArray};
use crate::cold;
use crate::convert::{ArrayExt, IntoPyArray, NpyIndex, ToNpyDims, ToPyArray};
use crate::dtype::Element;
use crate::dtype::{Element, PyArrayDescrMethods};
use crate::error::{
BorrowError, DimensionalityError, FromVecError, IgnoreError, NotContiguousError, TypeError,
DIMENSIONALITY_MISMATCH_ERR, MAX_DIMENSIONALITY_ERR,
Expand Down Expand Up @@ -278,10 +278,10 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
}

// Check if the element type matches `T`.
let src_dtype = arr_gil_ref.dtype();
let dst_dtype = T::get_dtype(ob.py());
if !src_dtype.is_equiv_to(dst_dtype) {
return Err(TypeError::new(src_dtype, dst_dtype).into());
let src_dtype = array.dtype();
let dst_dtype = T::get_dtype_bound(ob.py());
if !src_dtype.is_equiv_to(&dst_dtype) {
return Err(TypeError::new(src_dtype.into_gil_ref(), dst_dtype.into_gil_ref()).into());
}

Ok(array)
Expand Down Expand Up @@ -354,7 +354,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
let ptr = PY_ARRAY_API.PyArray_NewFromDescr(
py,
PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
T::get_dtype(py).into_dtype_ptr(),
T::get_dtype_bound(py).into_dtype_ptr(),
dims.ndim_cint(),
dims.as_dims_ptr(),
strides as *mut npy_intp, // strides
Expand All @@ -380,7 +380,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
let ptr = PY_ARRAY_API.PyArray_NewFromDescr(
py,
PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type),
T::get_dtype(py).into_dtype_ptr(),
T::get_dtype_bound(py).into_dtype_ptr(),
dims.ndim_cint(),
dims.as_dims_ptr(),
strides as *mut npy_intp, // strides
Expand Down Expand Up @@ -500,7 +500,7 @@ impl<T: Element, D: Dimension> PyArray<T, D> {
py,
dims.ndim_cint(),
dims.as_dims_ptr(),
T::get_dtype(py).into_dtype_ptr(),
T::get_dtype_bound(py).into_dtype_ptr(),
if is_fortran { -1 } else { 0 },
);
Self::from_owned_ptr(py, ptr)
Expand Down Expand Up @@ -1315,7 +1315,7 @@ impl<T: Element, D> PyArray<T, D> {
PY_ARRAY_API.PyArray_CastToType(
self.py(),
self.as_array_ptr(),
U::get_dtype(self.py()).into_dtype_ptr(),
U::get_dtype_bound(self.py()).into_dtype_ptr(),
if is_fortran { -1 } else { 0 },
)
};
Expand Down Expand Up @@ -1461,7 +1461,7 @@ impl<T: Element + AsPrimitive<f64>> PyArray<T, Ix1> {
start.as_(),
stop.as_(),
step.as_(),
T::get_dtype(py).num(),
T::get_dtype_bound(py).num(),
);
Self::from_owned_ptr(py, ptr)
}
Expand Down
2 changes: 1 addition & 1 deletion src/array_like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ where

let kwargs = if C::VAL {
let kwargs = PyDict::new(py);
kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?;
kwargs.set_item(intern!(py, "dtype"), T::get_dtype_bound(py))?;
Some(kwargs)
} else {
None
Expand Down
12 changes: 6 additions & 6 deletions src/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ use std::fmt;
use std::hash::Hash;
use std::marker::PhantomData;

use pyo3::{sync::GILProtected, Py, Python};
use pyo3::{sync::GILProtected, Bound, Py, Python};
use rustc_hash::FxHashMap;

use crate::dtype::{Element, PyArrayDescr};
use crate::dtype::{Element, PyArrayDescr, PyArrayDescrMethods};
use crate::npyffi::{PyArray_DatetimeDTypeMetaData, NPY_DATETIMEUNIT, NPY_TYPES};

/// Represents the [datetime units][datetime-units] supported by NumPy
Expand Down Expand Up @@ -156,7 +156,7 @@ impl<U: Unit> From<Datetime<U>> for i64 {
unsafe impl<U: Unit> Element for Datetime<U> {
const IS_COPY: bool = true;

fn get_dtype<'py>(py: Python<'py>) -> &'py PyArrayDescr {
fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
static DTYPES: TypeDescriptors = unsafe { TypeDescriptors::new(NPY_TYPES::NPY_DATETIME) };

DTYPES.from_unit(py, U::UNIT)
Expand Down Expand Up @@ -191,7 +191,7 @@ impl<U: Unit> From<Timedelta<U>> for i64 {
unsafe impl<U: Unit> Element for Timedelta<U> {
const IS_COPY: bool = true;

fn get_dtype<'py>(py: Python<'py>) -> &'py PyArrayDescr {
fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
static DTYPES: TypeDescriptors = unsafe { TypeDescriptors::new(NPY_TYPES::NPY_TIMEDELTA) };

DTYPES.from_unit(py, U::UNIT)
Expand Down Expand Up @@ -220,7 +220,7 @@ impl TypeDescriptors {
}

#[allow(clippy::wrong_self_convention)]
fn from_unit<'py>(&'py self, py: Python<'py>, unit: NPY_DATETIMEUNIT) -> &'py PyArrayDescr {
fn from_unit<'py>(&self, py: Python<'py>, unit: NPY_DATETIMEUNIT) -> Bound<'py, PyArrayDescr> {
let mut dtypes = self.dtypes.get(py).borrow_mut();

let dtype = match dtypes.get_or_insert_with(Default::default).entry(unit) {
Expand All @@ -241,7 +241,7 @@ impl TypeDescriptors {
}
};

dtype.clone().into_ref(py)
dtype.clone().into_bound(py)
}
}

Expand Down
Loading

0 comments on commit fad4e18

Please sign in to comment.