Skip to content

Commit

Permalink
deprecate inner, dot and einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
Icxolu committed Mar 30, 2024
1 parent 0b39d09 commit 5e304d7
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 69 deletions.
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ pub use crate::dtype::{
pub use crate::error::{BorrowError, FromVecError, NotContiguousError};
pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API};
pub use crate::strings::{PyFixedString, PyFixedUnicode};
#[allow(deprecated)]
pub use crate::sum_products::{dot, einsum, inner};
pub use crate::sum_products::{dot_bound, einsum_bound, inner_bound};
pub use crate::untyped_array::{PyUntypedArray, PyUntypedArrayMethods};

pub use ndarray::{array, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
Expand Down
143 changes: 110 additions & 33 deletions src/sum_products.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::ptr::null_mut;

use ndarray::{Dimension, IxDyn};
use pyo3::types::PyAnyMethods;
use pyo3::{AsPyPointer, Bound, FromPyObject, PyNativeType, PyResult};
use pyo3::{Borrowed, Bound, FromPyObject, PyNativeType, PyResult};

use crate::array::PyArray;
use crate::dtype::Element;
Expand All @@ -20,8 +20,33 @@ where
{
}

impl<'py, T, D> ArrayOrScalar<'py, T> for Bound<'py, PyArray<T, D>>
where
T: Element,
D: Dimension,
{
}

impl<'py, T> ArrayOrScalar<'py, T> for T where T: Element + FromPyObject<'py> {}

/// Deprecated form of [`inner_bound`]
#[deprecated(
since = "0.21.0",
note = "will be replaced by `inner_bound` in the future"
)]
pub fn inner<'py, T, DIN1, DIN2, OUT>(
array1: &'py PyArray<T, DIN1>,
array2: &'py PyArray<T, DIN2>,
) -> PyResult<OUT>
where
T: Element,
DIN1: Dimension,
DIN2: Dimension,
OUT: ArrayOrScalar<'py, T>,
{
inner_bound(&array1.as_borrowed(), &array2.as_borrowed())
}

/// Return the inner product of two arrays.
///
/// [NumPy's documentation][inner] has the details.
Expand All @@ -31,33 +56,33 @@ impl<'py, T> ArrayOrScalar<'py, T> for T where T: Element + FromPyObject<'py> {}
/// Note that this function can either return a scalar...
///
/// ```
/// use pyo3::Python;
/// use numpy::{inner, pyarray, PyArray0};
/// use pyo3::{Python, PyNativeType};
/// use numpy::{inner_bound, pyarray, PyArray0};
///
/// Python::with_gil(|py| {
/// let vector = pyarray![py, 1.0, 2.0, 3.0];
/// let result: f64 = inner(vector, vector).unwrap();
/// let vector = pyarray![py, 1.0, 2.0, 3.0].as_borrowed();
/// let result: f64 = inner_bound(&vector, &vector).unwrap();
/// assert_eq!(result, 14.0);
/// });
/// ```
///
/// ...or an array depending on its arguments.
///
/// ```
/// use pyo3::Python;
/// use numpy::{inner, pyarray, PyArray0};
/// use pyo3::{Python, Bound, PyNativeType};
/// use numpy::{inner_bound, pyarray, PyArray0, PyArrayMethods};
///
/// Python::with_gil(|py| {
/// let vector = pyarray![py, 1, 2, 3];
/// let result: &PyArray0<_> = inner(vector, vector).unwrap();
/// let vector = pyarray![py, 1, 2, 3].as_borrowed();
/// let result: Bound<'_, PyArray0<_>> = inner_bound(&vector, &vector).unwrap();
/// assert_eq!(result.item(), 14);
/// });
/// ```
///
/// [inner]: https://numpy.org/doc/stable/reference/generated/numpy.inner.html
pub fn inner<'py, T, DIN1, DIN2, OUT>(
array1: &'py PyArray<T, DIN1>,
array2: &'py PyArray<T, DIN2>,
pub fn inner_bound<'py, T, DIN1, DIN2, OUT>(
array1: &Bound<'py, PyArray<T, DIN1>>,
array2: &Bound<'py, PyArray<T, DIN2>>,
) -> PyResult<OUT>
where
T: Element,
Expand All @@ -73,6 +98,24 @@ where
obj.extract()
}

/// Deprecated form of [`dot_bound`]
#[deprecated(
since = "0.21.0",
note = "will be replaced by `dot_bound` in the future"
)]
pub fn dot<'py, T, DIN1, DIN2, OUT>(
array1: &'py PyArray<T, DIN1>,
array2: &'py PyArray<T, DIN2>,
) -> PyResult<OUT>
where
T: Element,
DIN1: Dimension,
DIN2: Dimension,
OUT: ArrayOrScalar<'py, T>,
{
dot_bound(&array1.as_borrowed(), &array2.as_borrowed())
}

/// Return the dot product of two arrays.
///
/// [NumPy's documentation][dot] has the details.
Expand All @@ -82,15 +125,15 @@ where
/// Note that this function can either return an array...
///
/// ```
/// use pyo3::Python;
/// use pyo3::{Python, Bound, PyNativeType};
/// use ndarray::array;
/// use numpy::{dot, pyarray, PyArray2};
/// use numpy::{dot_bound, pyarray, PyArray2, PyArrayMethods};
///
/// Python::with_gil(|py| {
/// let matrix = pyarray![py, [1, 0], [0, 1]];
/// let another_matrix = pyarray![py, [4, 1], [2, 2]];
/// let matrix = pyarray![py, [1, 0], [0, 1]].as_borrowed();
/// let another_matrix = pyarray![py, [4, 1], [2, 2]].as_borrowed();
///
/// let result: &PyArray2<_> = numpy::dot(matrix, another_matrix).unwrap();
/// let result: Bound<'_, PyArray2<_>> = dot_bound(&matrix, &another_matrix).unwrap();
///
/// assert_eq!(
/// result.readonly().as_array(),
Expand All @@ -102,20 +145,20 @@ where
/// ...or a scalar depending on its arguments.
///
/// ```
/// use pyo3::Python;
/// use numpy::{dot, pyarray, PyArray0};
/// use pyo3::{Python, PyNativeType};
/// use numpy::{dot_bound, pyarray, PyArray0};
///
/// Python::with_gil(|py| {
/// let vector = pyarray![py, 1.0, 2.0, 3.0];
/// let result: f64 = dot(vector, vector).unwrap();
/// let vector = pyarray![py, 1.0, 2.0, 3.0].as_borrowed();
/// let result: f64 = dot_bound(&vector, &vector).unwrap();
/// assert_eq!(result, 14.0);
/// });
/// ```
///
/// [dot]: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
pub fn dot<'py, T, DIN1, DIN2, OUT>(
array1: &'py PyArray<T, DIN1>,
array2: &'py PyArray<T, DIN2>,
pub fn dot_bound<'py, T, DIN1, DIN2, OUT>(
array1: &Bound<'py, PyArray<T, DIN1>>,
array2: &Bound<'py, PyArray<T, DIN2>>,
) -> PyResult<OUT>
where
T: Element,
Expand All @@ -131,10 +174,30 @@ where
obj.extract()
}

/// Deprecated form of [`einsum_bound`]
#[deprecated(
since = "0.21.0",
note = "will be replaced by `einsum_bound` in the future"
)]
pub fn einsum<'py, T, OUT>(subscripts: &str, arrays: &[&'py PyArray<T, IxDyn>]) -> PyResult<OUT>
where
T: Element,
OUT: ArrayOrScalar<'py, T>,
{
// Safety: &PyArray<T, IxDyn> has the same size and layout in memory as
// Borrowed<'_, '_, PyArray<T, IxDyn>>
einsum_bound(subscripts, unsafe {
std::slice::from_raw_parts(arrays.as_ptr().cast(), arrays.len())
})
}

/// Return the Einstein summation convention of given tensors.
///
/// This is usually invoked via the the [`einsum!`][crate::einsum!] macro.
pub fn einsum<'py, T, OUT>(subscripts: &str, arrays: &[&'py PyArray<T, IxDyn>]) -> PyResult<OUT>
pub fn einsum_bound<'py, T, OUT>(
subscripts: &str,
arrays: &[Borrowed<'_, 'py, PyArray<T, IxDyn>>],
) -> PyResult<OUT>
where
T: Element,
OUT: ArrayOrScalar<'py, T>,
Expand All @@ -161,6 +224,20 @@ where
obj.extract()
}

/// Deprecated form of [`einsum_bound!`]
#[deprecated(
since = "0.21.0",
note = "will be replaced by `einsum_bound!` in the future"
)]
#[macro_export]
macro_rules! einsum {
($subscripts:literal $(,$array:ident)+ $(,)*) => {{
use pyo3::PyNativeType;
let arrays = [$($array.to_dyn().as_borrowed(),)+];
$crate::einsum_bound(concat!($subscripts, "\0"), &arrays)
}};
}

/// Return the Einstein summation convention of given tensors.
///
/// For more about the Einstein summation convention, please refer to
Expand All @@ -169,15 +246,15 @@ where
/// # Example
///
/// ```
/// use pyo3::Python;
/// use pyo3::{Python, Bound, PyNativeType};
/// use ndarray::array;
/// use numpy::{einsum, pyarray, PyArray, PyArray2, PyArrayMethods};
/// use numpy::{einsum_bound, pyarray, PyArray, PyArray2, PyArrayMethods};
///
/// Python::with_gil(|py| {
/// let tensor = PyArray::arange_bound(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap().into_gil_ref();
/// let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]];
/// let tensor = PyArray::arange_bound(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
/// let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]].as_borrowed();
///
/// let result: &PyArray2<_> = einsum!("ijk,ji->ik", tensor, another_tensor).unwrap();
/// let result: Bound<'_, PyArray2<_>> = einsum_bound!("ijk,ji->ik", tensor, another_tensor).unwrap();
///
/// assert_eq!(
/// result.readonly().as_array(),
Expand All @@ -188,9 +265,9 @@ where
///
/// [einsum]: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
#[macro_export]
macro_rules! einsum {
macro_rules! einsum_bound {
($subscripts:literal $(,$array:ident)+ $(,)*) => {{
let arrays = [$($array.to_dyn(),)+];
$crate::einsum(concat!($subscripts, "\0"), &arrays)
let arrays = [$($array.to_dyn().as_borrowed(),)+];
$crate::einsum_bound(concat!($subscripts, "\0"), &arrays)
}};
}
72 changes: 36 additions & 36 deletions tests/sum_products.rs
Original file line number Diff line number Diff line change
@@ -1,54 +1,55 @@
use numpy::{array, dot, einsum, inner, pyarray, PyArray0, PyArray1, PyArray2, PyArrayMethods};
use pyo3::Python;
use numpy::prelude::*;
use numpy::{array, dot_bound, einsum_bound, inner_bound, pyarray, PyArray0, PyArray1, PyArray2};
use pyo3::{Bound, PyNativeType, Python};

#[test]
fn test_dot() {
Python::with_gil(|py| {
let a = pyarray![py, [1, 0], [0, 1]];
let b = pyarray![py, [4, 1], [2, 2]];
let c: &PyArray2<_> = dot(a, b).unwrap();
let a = pyarray![py, [1, 0], [0, 1]].as_borrowed();
let b = pyarray![py, [4, 1], [2, 2]].as_borrowed();
let c: Bound<'_, PyArray2<_>> = dot_bound(&a, &b).unwrap();
assert_eq!(c.readonly().as_array(), array![[4, 1], [2, 2]]);

let a = pyarray![py, 1, 2, 3];
let err = dot::<_, _, _, &PyArray2<_>>(a, b).unwrap_err();
let a = pyarray![py, 1, 2, 3].as_borrowed();
let err = dot_bound::<_, _, _, Bound<'_, PyArray2<_>>>(&a, &b).unwrap_err();
assert!(err.to_string().contains("not aligned"), "{}", err);

let a = pyarray![py, 1, 2, 3];
let b = pyarray![py, 0, 1, 0];
let c: &PyArray0<_> = dot(a, b).unwrap();
let a = pyarray![py, 1, 2, 3].as_borrowed();
let b = pyarray![py, 0, 1, 0].as_borrowed();
let c: Bound<'_, PyArray0<_>> = dot_bound(&a, &b).unwrap();
assert_eq!(c.item(), 2);
let c: i32 = dot(a, b).unwrap();
let c: i32 = dot_bound(&a, &b).unwrap();
assert_eq!(c, 2);

let a = pyarray![py, 1.0, 2.0, 3.0];
let b = pyarray![py, 0.0, 0.0, 0.0];
let c: f64 = dot(a, b).unwrap();
let a = pyarray![py, 1.0, 2.0, 3.0].as_borrowed();
let b = pyarray![py, 0.0, 0.0, 0.0].as_borrowed();
let c: f64 = dot_bound(&a, &b).unwrap();
assert_eq!(c, 0.0);
});
}

#[test]
fn test_inner() {
Python::with_gil(|py| {
let a = pyarray![py, 1, 2, 3];
let b = pyarray![py, 0, 1, 0];
let c: &PyArray0<_> = inner(a, b).unwrap();
let a = pyarray![py, 1, 2, 3].as_borrowed();
let b = pyarray![py, 0, 1, 0].as_borrowed();
let c: Bound<'_, PyArray0<_>> = inner_bound(&a, &b).unwrap();
assert_eq!(c.item(), 2);
let c: i32 = inner(a, b).unwrap();
let c: i32 = inner_bound(&a, &b).unwrap();
assert_eq!(c, 2);

let a = pyarray![py, 1.0, 2.0, 3.0];
let b = pyarray![py, 0.0, 0.0, 0.0];
let c: f64 = inner(a, b).unwrap();
let a = pyarray![py, 1.0, 2.0, 3.0].as_borrowed();
let b = pyarray![py, 0.0, 0.0, 0.0].as_borrowed();
let c: f64 = inner_bound(&a, &b).unwrap();
assert_eq!(c, 0.0);

let a = pyarray![py, [1, 0], [0, 1]];
let b = pyarray![py, [4, 1], [2, 2]];
let c: &PyArray2<_> = inner(a, b).unwrap();
let a = pyarray![py, [1, 0], [0, 1]].as_borrowed();
let b = pyarray![py, [4, 1], [2, 2]].as_borrowed();
let c: Bound<'_, PyArray2<_>> = inner_bound(&a, &b).unwrap();
assert_eq!(c.readonly().as_array(), array![[4, 2], [1, 2]]);

let a = pyarray![py, 1, 2, 3];
let err = inner::<_, _, _, &PyArray2<_>>(a, b).unwrap_err();
let a = pyarray![py, 1, 2, 3].as_borrowed();
let err = inner_bound::<_, _, _, Bound<'_, PyArray2<_>>>(&a, &b).unwrap_err();
assert!(err.to_string().contains("not aligned"), "{}", err);
});
}
Expand All @@ -58,27 +59,26 @@ fn test_einsum() {
Python::with_gil(|py| {
let a = PyArray1::<i32>::arange_bound(py, 0, 25, 1)
.reshape([5, 5])
.unwrap()
.into_gil_ref();
let b = pyarray![py, 0, 1, 2, 3, 4];
let c = pyarray![py, [0, 1, 2], [3, 4, 5]];
.unwrap();
let b = pyarray![py, 0, 1, 2, 3, 4].as_borrowed();
let c = pyarray![py, [0, 1, 2], [3, 4, 5]].as_borrowed();

let d: &PyArray0<_> = einsum!("ii", a).unwrap();
let d: Bound<'_, PyArray0<_>> = einsum_bound!("ii", a).unwrap();
assert_eq!(d.item(), 60);

let d: i32 = einsum!("ii", a).unwrap();
let d: i32 = einsum_bound!("ii", a).unwrap();
assert_eq!(d, 60);

let d: &PyArray1<_> = einsum!("ii->i", a).unwrap();
let d: Bound<'_, PyArray1<_>> = einsum_bound!("ii->i", a).unwrap();
assert_eq!(d.readonly().as_array(), array![0, 6, 12, 18, 24]);

let d: &PyArray1<_> = einsum!("ij->i", a).unwrap();
let d: Bound<'_, PyArray1<_>> = einsum_bound!("ij->i", a).unwrap();
assert_eq!(d.readonly().as_array(), array![10, 35, 60, 85, 110]);

let d: &PyArray2<_> = einsum!("ji", c).unwrap();
let d: Bound<'_, PyArray2<_>> = einsum_bound!("ji", c).unwrap();
assert_eq!(d.readonly().as_array(), array![[0, 3], [1, 4], [2, 5]]);

let d: &PyArray1<_> = einsum!("ij,j", a, b).unwrap();
let d: Bound<'_, PyArray1<_>> = einsum_bound!("ij,j", a, b).unwrap();
assert_eq!(d.readonly().as_array(), array![30, 80, 130, 180, 230]);
});
}

0 comments on commit 5e304d7

Please sign in to comment.