Skip to content

Commit

Permalink
Fix axis0 calls in reduction Python binding (#1459)
Browse files Browse the repository at this point in the history
* max and min now use MinMaxAtomicSupportFactory

These functions were using ArithmeticAtomicSupportFactory, which disables atomics for floating point types

* Resolves #1455

This issue was caused by a typo: where the `axis0` kernels
for tree and atomic reductions should have been called, the
`axis1` kernels were called instead

* Adds tests for #1455 resolution
  • Loading branch information
ndgrigorian authored Oct 27, 2023
1 parent d82f3a9 commit 02e7714
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,12 @@ template <typename fnT, typename T> struct MinMaxAtomicSupportFactory
};

template <typename fnT, typename T>
struct MaxAtomicSupportFactory : public ArithmeticAtomicSupportFactory<fnT, T>
struct MaxAtomicSupportFactory : public MinMaxAtomicSupportFactory<fnT, T>
{
};

template <typename fnT, typename T>
struct MinAtomicSupportFactory : public ArithmeticAtomicSupportFactory<fnT, T>
struct MinAtomicSupportFactory : public MinMaxAtomicSupportFactory<fnT, T>
{
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,10 @@ std::pair<sycl::event, sycl::event> py_reduction_over_axis(
typename std::remove_all_extents<contig_fnT>::type;
contig_fn_ptr_T fn;
if (supports_atomics) {
fn = axis1_atomic_dispatch_table[src_typeid][dst_typeid];
fn = axis0_atomic_dispatch_table[src_typeid][dst_typeid];
}
else {
fn = axis1_temps_dispatch_table[src_typeid][dst_typeid];
fn = axis0_temps_dispatch_table[src_typeid][dst_typeid];
}
if (fn != nullptr) {
sycl::event reduction_over_axis0_contig_ev =
Expand Down Expand Up @@ -727,7 +727,7 @@ std::pair<sycl::event, sycl::event> py_tree_reduction_over_axis(
}
}
else if (mat_reduce_over_axis0) {
auto fn = axis1_temps_dispatch_table[src_typeid][dst_typeid];
auto fn = axis0_temps_dispatch_table[src_typeid][dst_typeid];
if (fn != nullptr) {
sycl::event reduction_over_axis0_contig_ev =
fn(exec_q, iter_nelems, reduction_nelems, src.get_data(),
Expand Down Expand Up @@ -929,7 +929,6 @@ std::pair<sycl::event, sycl::event> py_search_over_axis(
}

using dpctl::tensor::py_internal::simplify_iteration_space;
using dpctl::tensor::py_internal::simplify_iteration_space_1;

auto const &src_shape_vecs = src.get_shape_vector();
auto const &src_strides_vecs = src.get_strides_vector();
Expand Down
30 changes: 30 additions & 0 deletions dpctl/tests/test_tensor_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,36 @@ def test_axis0_bug():
assert dpt.all(s == expected)


def test_sum_axis1_axis0():
    """Regression test for gh-1455: `sum` over the first and last axes."""
    get_queue_or_skip()

    # The atomic case is exercised in `test_usm_ndarray_reductions`;
    # this test verifies the tree reduction path.
    data = dpt.reshape(dpt.arange(3 * 4 * 5, dtype="f4"), (3, 4, 5))

    # Reduce over the leading axis and compare with hand-computed totals.
    axis0_total = dpt.sum(data, axis=0)
    axis0_expected = dpt.asarray(
        [
            [60, 63, 66, 69, 72],
            [75, 78, 81, 84, 87],
            [90, 93, 96, 99, 102],
            [105, 108, 111, 114, 117],
        ],
        dtype="f4",
    )
    tol = dpt.finfo(axis0_total.dtype).resolution
    assert dpt.allclose(axis0_total, axis0_expected, atol=tol, rtol=tol)

    # Reverse the trailing axis, then reduce over it.
    flipped = dpt.flip(data, axis=2)
    axis2_total = dpt.sum(flipped, axis=2)
    axis2_expected = dpt.asarray(
        [[10, 35, 60, 85], [110, 135, 160, 185], [210, 235, 260, 285]],
        dtype="f4",
    )
    assert dpt.allclose(axis2_total, axis2_expected, atol=tol, rtol=tol)


def _any_complex(dtypes):
return any(dpt.isdtype(dpt.dtype(dt), "complex floating") for dt in dtypes)

Expand Down
39 changes: 39 additions & 0 deletions dpctl/tests/test_usm_ndarray_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,20 @@ def test_max_min_axis():
assert dpt.all(m == x[:, 0, 0, :, 0])


def test_max_axis1_axis0():
    """Regression test for gh-1455: `max` over the first and last axes."""
    get_queue_or_skip()

    values = dpt.reshape(dpt.arange(3 * 4 * 5), (3, 4, 5))

    # The test expects the maximum over axis 0 to equal the last slice
    # taken along that axis.
    axis0_max = dpt.max(values, axis=0)
    assert dpt.all(axis0_max == values[-1, :, :])

    # After reversing axis 2, the test expects the maximum over that axis
    # to equal the slice at index 0.
    reversed_values = dpt.flip(values, axis=2)
    axis2_max = dpt.max(reversed_values, axis=2)
    assert dpt.all(axis2_max == reversed_values[:, :, 0])


def test_reduction_keepdims():
get_queue_or_skip()

Expand Down Expand Up @@ -440,3 +454,28 @@ def test_hypot_complex():
x = dpt.zeros(1, dtype="c8")
with pytest.raises(TypeError):
dpt.reduce_hypot(x)


def test_tree_reduction_axis1_axis0():
    """Regression test for gh-1455: `logsumexp` over first and last axes."""
    get_queue_or_skip()

    data = dpt.reshape(dpt.arange(3 * 4 * 5, dtype="f4"), (3, 4, 5))

    # Reduce over the leading axis; NumPy's logaddexp reduction is the
    # reference result.
    axis0_result = dpt.logsumexp(data, axis=0)
    tol = dpt.finfo(axis0_result.dtype).resolution
    axis0_reference = np.logaddexp.reduce(
        dpt.asnumpy(data), axis=0, dtype=axis0_result.dtype
    )
    assert_allclose(
        dpt.asnumpy(axis0_result), axis0_reference, rtol=tol, atol=tol
    )

    # Reverse the trailing axis, then reduce over it.
    flipped = dpt.flip(data, axis=2)
    axis2_result = dpt.logsumexp(flipped, axis=2)
    axis2_reference = np.logaddexp.reduce(
        dpt.asnumpy(flipped), axis=2, dtype=axis2_result.dtype
    )
    assert_allclose(
        dpt.asnumpy(axis2_result), axis2_reference, rtol=tol, atol=tol
    )

0 comments on commit 02e7714

Please sign in to comment.