Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

deprecate inner, dot and einsum #421

Merged
merged 1 commit into from
Mar 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ pub use crate::dtype::{
pub use crate::error::{BorrowError, FromVecError, NotContiguousError};
pub use crate::npyffi::{PY_ARRAY_API, PY_UFUNC_API};
pub use crate::strings::{PyFixedString, PyFixedUnicode};
#[allow(deprecated)]
pub use crate::sum_products::{dot, einsum, inner};
pub use crate::sum_products::{dot_bound, einsum_bound, inner_bound};
pub use crate::untyped_array::{PyUntypedArray, PyUntypedArrayMethods};

pub use ndarray::{array, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
Expand Down
143 changes: 110 additions & 33 deletions src/sum_products.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::ptr::null_mut;

use ndarray::{Dimension, IxDyn};
use pyo3::types::PyAnyMethods;
use pyo3::{AsPyPointer, Bound, FromPyObject, PyNativeType, PyResult};
use pyo3::{Borrowed, Bound, FromPyObject, PyNativeType, PyResult};

use crate::array::PyArray;
use crate::dtype::Element;
Expand All @@ -20,8 +20,33 @@ where
{
}

impl<'py, T, D> ArrayOrScalar<'py, T> for Bound<'py, PyArray<T, D>>
where
T: Element,
D: Dimension,
{
}

impl<'py, T> ArrayOrScalar<'py, T> for T where T: Element + FromPyObject<'py> {}

/// Deprecated form of [`inner_bound`]
#[deprecated(
since = "0.21.0",
note = "will be replaced by `inner_bound` in the future"
)]
pub fn inner<'py, T, DIN1, DIN2, OUT>(
array1: &'py PyArray<T, DIN1>,
array2: &'py PyArray<T, DIN2>,
) -> PyResult<OUT>
where
T: Element,
DIN1: Dimension,
DIN2: Dimension,
OUT: ArrayOrScalar<'py, T>,
{
inner_bound(&array1.as_borrowed(), &array2.as_borrowed())
}

/// Return the inner product of two arrays.
///
/// [NumPy's documentation][inner] has the details.
Expand All @@ -31,33 +56,33 @@ impl<'py, T> ArrayOrScalar<'py, T> for T where T: Element + FromPyObject<'py> {}
/// Note that this function can either return a scalar...
///
/// ```
/// use pyo3::Python;
/// use numpy::{inner, pyarray, PyArray0};
/// use pyo3::{Python, PyNativeType};
/// use numpy::{inner_bound, pyarray, PyArray0};
///
/// Python::with_gil(|py| {
/// let vector = pyarray![py, 1.0, 2.0, 3.0];
/// let result: f64 = inner(vector, vector).unwrap();
/// let vector = pyarray![py, 1.0, 2.0, 3.0].as_borrowed();
/// let result: f64 = inner_bound(&vector, &vector).unwrap();
/// assert_eq!(result, 14.0);
/// });
/// ```
///
/// ...or an array depending on its arguments.
///
/// ```
/// use pyo3::Python;
/// use numpy::{inner, pyarray, PyArray0};
/// use pyo3::{Python, Bound, PyNativeType};
/// use numpy::{inner_bound, pyarray, PyArray0, PyArray0Methods};
///
/// Python::with_gil(|py| {
/// let vector = pyarray![py, 1, 2, 3];
/// let result: &PyArray0<_> = inner(vector, vector).unwrap();
/// let vector = pyarray![py, 1, 2, 3].as_borrowed();
/// let result: Bound<'_, PyArray0<_>> = inner_bound(&vector, &vector).unwrap();
/// assert_eq!(result.item(), 14);
/// });
/// ```
///
/// [inner]: https://numpy.org/doc/stable/reference/generated/numpy.inner.html
pub fn inner<'py, T, DIN1, DIN2, OUT>(
array1: &'py PyArray<T, DIN1>,
array2: &'py PyArray<T, DIN2>,
pub fn inner_bound<'py, T, DIN1, DIN2, OUT>(
array1: &Bound<'py, PyArray<T, DIN1>>,
array2: &Bound<'py, PyArray<T, DIN2>>,
) -> PyResult<OUT>
where
T: Element,
Expand All @@ -73,6 +98,24 @@ where
obj.extract()
}

/// Deprecated form of [`dot_bound`]
#[deprecated(
since = "0.21.0",
note = "will be replaced by `dot_bound` in the future"
)]
pub fn dot<'py, T, DIN1, DIN2, OUT>(
array1: &'py PyArray<T, DIN1>,
array2: &'py PyArray<T, DIN2>,
) -> PyResult<OUT>
where
T: Element,
DIN1: Dimension,
DIN2: Dimension,
OUT: ArrayOrScalar<'py, T>,
{
dot_bound(&array1.as_borrowed(), &array2.as_borrowed())
}

/// Return the dot product of two arrays.
///
/// [NumPy's documentation][dot] has the details.
Expand All @@ -82,15 +125,15 @@ where
/// Note that this function can either return an array...
///
/// ```
/// use pyo3::Python;
/// use pyo3::{Python, Bound, PyNativeType};
/// use ndarray::array;
/// use numpy::{dot, pyarray, PyArray2};
/// use numpy::{dot_bound, pyarray, PyArray2, PyArrayMethods};
///
/// Python::with_gil(|py| {
/// let matrix = pyarray![py, [1, 0], [0, 1]];
/// let another_matrix = pyarray![py, [4, 1], [2, 2]];
/// let matrix = pyarray![py, [1, 0], [0, 1]].as_borrowed();
/// let another_matrix = pyarray![py, [4, 1], [2, 2]].as_borrowed();
///
/// let result: &PyArray2<_> = numpy::dot(matrix, another_matrix).unwrap();
/// let result: Bound<'_, PyArray2<_>> = dot_bound(&matrix, &another_matrix).unwrap();
///
/// assert_eq!(
/// result.readonly().as_array(),
Expand All @@ -102,20 +145,20 @@ where
/// ...or a scalar depending on its arguments.
///
/// ```
/// use pyo3::Python;
/// use numpy::{dot, pyarray, PyArray0};
/// use pyo3::{Python, PyNativeType};
/// use numpy::{dot_bound, pyarray, PyArray0};
///
/// Python::with_gil(|py| {
/// let vector = pyarray![py, 1.0, 2.0, 3.0];
/// let result: f64 = dot(vector, vector).unwrap();
/// let vector = pyarray![py, 1.0, 2.0, 3.0].as_borrowed();
/// let result: f64 = dot_bound(&vector, &vector).unwrap();
/// assert_eq!(result, 14.0);
/// });
/// ```
///
/// [dot]: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
pub fn dot<'py, T, DIN1, DIN2, OUT>(
array1: &'py PyArray<T, DIN1>,
array2: &'py PyArray<T, DIN2>,
pub fn dot_bound<'py, T, DIN1, DIN2, OUT>(
array1: &Bound<'py, PyArray<T, DIN1>>,
array2: &Bound<'py, PyArray<T, DIN2>>,
) -> PyResult<OUT>
where
T: Element,
Expand All @@ -131,10 +174,30 @@ where
obj.extract()
}

/// Deprecated form of [`einsum_bound`]
#[deprecated(
since = "0.21.0",
note = "will be replaced by `einsum_bound` in the future"
)]
pub fn einsum<'py, T, OUT>(subscripts: &str, arrays: &[&'py PyArray<T, IxDyn>]) -> PyResult<OUT>
where
T: Element,
OUT: ArrayOrScalar<'py, T>,
{
// Safety: &PyArray<T, IxDyn> has the same size and layout in memory as
// Borrowed<'_, '_, PyArray<T, IxDyn>>
einsum_bound(subscripts, unsafe {
std::slice::from_raw_parts(arrays.as_ptr().cast(), arrays.len())
})
}

/// Return the Einstein summation convention of given tensors.
///
/// This is usually invoked via the the [`einsum!`][crate::einsum!] macro.
pub fn einsum<'py, T, OUT>(subscripts: &str, arrays: &[&'py PyArray<T, IxDyn>]) -> PyResult<OUT>
pub fn einsum_bound<'py, T, OUT>(
subscripts: &str,
arrays: &[Borrowed<'_, 'py, PyArray<T, IxDyn>>],
) -> PyResult<OUT>
where
T: Element,
OUT: ArrayOrScalar<'py, T>,
Expand All @@ -161,6 +224,20 @@ where
obj.extract()
}

/// Deprecated form of [`einsum_bound!`]
#[deprecated(
since = "0.21.0",
note = "will be replaced by `einsum_bound!` in the future"
)]
#[macro_export]
macro_rules! einsum {
($subscripts:literal $(,$array:ident)+ $(,)*) => {{
use pyo3::PyNativeType;
let arrays = [$($array.to_dyn().as_borrowed(),)+];
$crate::einsum_bound(concat!($subscripts, "\0"), &arrays)
}};
}

/// Return the Einstein summation convention of given tensors.
///
/// For more about the Einstein summation convention, please refer to
Expand All @@ -169,15 +246,15 @@ where
/// # Example
///
/// ```
/// use pyo3::Python;
/// use pyo3::{Python, Bound, PyNativeType};
/// use ndarray::array;
/// use numpy::{einsum, pyarray, PyArray, PyArray2, PyArrayMethods};
/// use numpy::{einsum_bound, pyarray, PyArray, PyArray2, PyArrayMethods};
///
/// Python::with_gil(|py| {
/// let tensor = PyArray::arange_bound(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap().into_gil_ref();
/// let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]];
/// let tensor = PyArray::arange_bound(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
/// let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]].as_borrowed();
///
/// let result: &PyArray2<_> = einsum!("ijk,ji->ik", tensor, another_tensor).unwrap();
/// let result: Bound<'_, PyArray2<_>> = einsum_bound!("ijk,ji->ik", tensor, another_tensor).unwrap();
///
/// assert_eq!(
/// result.readonly().as_array(),
Expand All @@ -188,9 +265,9 @@ where
///
/// [einsum]: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
#[macro_export]
macro_rules! einsum {
macro_rules! einsum_bound {
($subscripts:literal $(,$array:ident)+ $(,)*) => {{
let arrays = [$($array.to_dyn(),)+];
$crate::einsum(concat!($subscripts, "\0"), &arrays)
let arrays = [$($array.to_dyn().as_borrowed(),)+];
$crate::einsum_bound(concat!($subscripts, "\0"), &arrays)
}};
}
72 changes: 36 additions & 36 deletions tests/sum_products.rs
Original file line number Diff line number Diff line change
@@ -1,54 +1,55 @@
use numpy::{array, dot, einsum, inner, pyarray, PyArray0, PyArray1, PyArray2, PyArrayMethods};
use pyo3::Python;
use numpy::prelude::*;
use numpy::{array, dot_bound, einsum_bound, inner_bound, pyarray, PyArray0, PyArray1, PyArray2};
use pyo3::{Bound, PyNativeType, Python};

#[test]
fn test_dot() {
Python::with_gil(|py| {
let a = pyarray![py, [1, 0], [0, 1]];
let b = pyarray![py, [4, 1], [2, 2]];
let c: &PyArray2<_> = dot(a, b).unwrap();
let a = pyarray![py, [1, 0], [0, 1]].as_borrowed();
let b = pyarray![py, [4, 1], [2, 2]].as_borrowed();
let c: Bound<'_, PyArray2<_>> = dot_bound(&a, &b).unwrap();
assert_eq!(c.readonly().as_array(), array![[4, 1], [2, 2]]);

let a = pyarray![py, 1, 2, 3];
let err = dot::<_, _, _, &PyArray2<_>>(a, b).unwrap_err();
let a = pyarray![py, 1, 2, 3].as_borrowed();
let err = dot_bound::<_, _, _, Bound<'_, PyArray2<_>>>(&a, &b).unwrap_err();
assert!(err.to_string().contains("not aligned"), "{}", err);

let a = pyarray![py, 1, 2, 3];
let b = pyarray![py, 0, 1, 0];
let c: &PyArray0<_> = dot(a, b).unwrap();
let a = pyarray![py, 1, 2, 3].as_borrowed();
let b = pyarray![py, 0, 1, 0].as_borrowed();
let c: Bound<'_, PyArray0<_>> = dot_bound(&a, &b).unwrap();
assert_eq!(c.item(), 2);
let c: i32 = dot(a, b).unwrap();
let c: i32 = dot_bound(&a, &b).unwrap();
assert_eq!(c, 2);

let a = pyarray![py, 1.0, 2.0, 3.0];
let b = pyarray![py, 0.0, 0.0, 0.0];
let c: f64 = dot(a, b).unwrap();
let a = pyarray![py, 1.0, 2.0, 3.0].as_borrowed();
let b = pyarray![py, 0.0, 0.0, 0.0].as_borrowed();
let c: f64 = dot_bound(&a, &b).unwrap();
assert_eq!(c, 0.0);
});
}

#[test]
fn test_inner() {
Python::with_gil(|py| {
let a = pyarray![py, 1, 2, 3];
let b = pyarray![py, 0, 1, 0];
let c: &PyArray0<_> = inner(a, b).unwrap();
let a = pyarray![py, 1, 2, 3].as_borrowed();
let b = pyarray![py, 0, 1, 0].as_borrowed();
let c: Bound<'_, PyArray0<_>> = inner_bound(&a, &b).unwrap();
assert_eq!(c.item(), 2);
let c: i32 = inner(a, b).unwrap();
let c: i32 = inner_bound(&a, &b).unwrap();
assert_eq!(c, 2);

let a = pyarray![py, 1.0, 2.0, 3.0];
let b = pyarray![py, 0.0, 0.0, 0.0];
let c: f64 = inner(a, b).unwrap();
let a = pyarray![py, 1.0, 2.0, 3.0].as_borrowed();
let b = pyarray![py, 0.0, 0.0, 0.0].as_borrowed();
let c: f64 = inner_bound(&a, &b).unwrap();
assert_eq!(c, 0.0);

let a = pyarray![py, [1, 0], [0, 1]];
let b = pyarray![py, [4, 1], [2, 2]];
let c: &PyArray2<_> = inner(a, b).unwrap();
let a = pyarray![py, [1, 0], [0, 1]].as_borrowed();
let b = pyarray![py, [4, 1], [2, 2]].as_borrowed();
let c: Bound<'_, PyArray2<_>> = inner_bound(&a, &b).unwrap();
assert_eq!(c.readonly().as_array(), array![[4, 2], [1, 2]]);

let a = pyarray![py, 1, 2, 3];
let err = inner::<_, _, _, &PyArray2<_>>(a, b).unwrap_err();
let a = pyarray![py, 1, 2, 3].as_borrowed();
let err = inner_bound::<_, _, _, Bound<'_, PyArray2<_>>>(&a, &b).unwrap_err();
assert!(err.to_string().contains("not aligned"), "{}", err);
});
}
Expand All @@ -58,27 +59,26 @@ fn test_einsum() {
Python::with_gil(|py| {
let a = PyArray1::<i32>::arange_bound(py, 0, 25, 1)
.reshape([5, 5])
.unwrap()
.into_gil_ref();
let b = pyarray![py, 0, 1, 2, 3, 4];
let c = pyarray![py, [0, 1, 2], [3, 4, 5]];
.unwrap();
let b = pyarray![py, 0, 1, 2, 3, 4].as_borrowed();
let c = pyarray![py, [0, 1, 2], [3, 4, 5]].as_borrowed();

let d: &PyArray0<_> = einsum!("ii", a).unwrap();
let d: Bound<'_, PyArray0<_>> = einsum_bound!("ii", a).unwrap();
assert_eq!(d.item(), 60);

let d: i32 = einsum!("ii", a).unwrap();
let d: i32 = einsum_bound!("ii", a).unwrap();
assert_eq!(d, 60);

let d: &PyArray1<_> = einsum!("ii->i", a).unwrap();
let d: Bound<'_, PyArray1<_>> = einsum_bound!("ii->i", a).unwrap();
assert_eq!(d.readonly().as_array(), array![0, 6, 12, 18, 24]);

let d: &PyArray1<_> = einsum!("ij->i", a).unwrap();
let d: Bound<'_, PyArray1<_>> = einsum_bound!("ij->i", a).unwrap();
assert_eq!(d.readonly().as_array(), array![10, 35, 60, 85, 110]);

let d: &PyArray2<_> = einsum!("ji", c).unwrap();
let d: Bound<'_, PyArray2<_>> = einsum_bound!("ji", c).unwrap();
assert_eq!(d.readonly().as_array(), array![[0, 3], [1, 4], [2, 5]]);

let d: &PyArray1<_> = einsum!("ij,j", a, b).unwrap();
let d: Bound<'_, PyArray1<_>> = einsum_bound!("ij,j", a, b).unwrap();
assert_eq!(d.readonly().as_array(), array![30, 80, 130, 180, 230]);
});
}
Loading