Skip to content

Fix for gh-1468 in arithmetic reduction when type promotion is needed #1470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 50 additions & 20 deletions dpctl/tensor/_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def _reduction_over_axis(
dpt.full(
res_shape,
_identity,
dtype=_default_reduction_type_fn(inp_dt, q),
dtype=dtype,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this change, the call to astype is no longer necessary.

It also means that when logsumexp or reduce_hypot reduce over an empty axis or array (e.g., dpt.logsumexp(dpt.ones((1, 0, 1), dtype="i4"), axis=1, dtype="i4")) you get OverflowError instead of silently casting the identity to the output type.

For now, the astype can be removed. I've experimented with removing this branch in #1465 too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, and it should not be dtype, it should be res_dt.

usm_type=res_usm_type,
sycl_queue=q,
),
Expand All @@ -142,21 +142,51 @@ def _reduction_over_axis(
"Automatically determined reduction data type does not "
"have direct implementation"
)
tmp_dt = _default_reduction_type_fn(inp_dt, q)
tmp = dpt.empty(
res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q
)
ht_e_tmp, r_e = _reduction_fn(
src=arr2, trailing_dims_to_reduce=red_nd, dst=tmp, sycl_queue=q
)
host_tasks_list.append(ht_e_tmp)
res = dpt.empty(
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
)
ht_e, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=tmp, dst=res, sycl_queue=q, depends=[r_e]
)
host_tasks_list.append(ht_e)
if _dtype_supported(res_dt, res_dt, res_usm_type, q):
tmp = dpt.empty(
arr2.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
)
ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
src=arr2, dst=tmp, sycl_queue=q
)
host_tasks_list.append(ht_e_cpy)
res = dpt.empty(
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
)
ht_e_red, _ = _reduction_fn(
src=tmp,
trailing_dims_to_reduce=red_nd,
dst=res,
sycl_queue=q,
depends=[cpy_e],
)
host_tasks_list.append(ht_e_red)
else:
buf_dt = _default_reduction_type_fn(inp_dt, q)
tmp = dpt.empty(
arr2.shape, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q
)
ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
src=arr2, dst=tmp, sycl_queue=q
)
tmp_res = dpt.empty(
res_shape, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q
)
host_tasks_list.append(ht_e_cpy)
res = dpt.empty(
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
)
ht_e_red, r_e = _reduction_fn(
src=tmp,
trailing_dims_to_reduce=red_nd,
dst=tmp_res,
sycl_queue=q,
depends=[cpy_e],
)
ht_e_cpy2, _ = ti._copy_usm_ndarray_into_usm_ndarray(
src=tmp_res, dst=res, sycl_queue=q, depends=[r_e]
)
host_tasks_list.append(ht_e_cpy2)

if keepdims:
res_shape = res_shape + (1,) * red_nd
Expand Down Expand Up @@ -445,7 +475,7 @@ def _comparison_over_axis(x, axis, keepdims, _reduction_fn):


def max(x, axis=None, keepdims=False):
"""max(x, axis=None, dtype=None, keepdims=False)
"""max(x, axis=None, keepdims=False)

Calculates the maximum value of the input array `x`.

Expand Down Expand Up @@ -473,7 +503,7 @@ def max(x, axis=None, keepdims=False):


def min(x, axis=None, keepdims=False):
"""min(x, axis=None, dtype=None, keepdims=False)
"""min(x, axis=None, keepdims=False)

Calculates the minimum value of the input array `x`.

Expand Down Expand Up @@ -550,7 +580,7 @@ def _search_over_axis(x, axis, keepdims, _reduction_fn):


def argmax(x, axis=None, keepdims=False):
"""argmax(x, axis=None, dtype=None, keepdims=False)
"""argmax(x, axis=None, keepdims=False)

Returns the indices of the maximum values of the input array `x` along a
specified axis.
Expand Down Expand Up @@ -582,7 +612,7 @@ def argmax(x, axis=None, keepdims=False):


def argmin(x, axis=None, keepdims=False):
"""argmin(x, axis=None, dtype=None, keepdims=False)
"""argmin(x, axis=None, keepdims=False)

Returns the indices of the minimum values of the input array `x` along a
specified axis.
Expand Down
9 changes: 9 additions & 0 deletions dpctl/tests/test_tensor_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,3 +329,12 @@ def test_prod_arg_out_dtype_matrix(arg_dtype, out_dtype):
assert isinstance(r, dpt.usm_ndarray)
assert r.dtype == dpt.dtype(out_dtype)
assert dpt.all(r == 1)


def test_gh_1468():
"See https://github.com/IntelPython/dpctl/issues/1468"
get_queue_or_skip()

a = dpt.full((2, 3, 4), 123456789, dtype=dpt.int32)
t = dpt.sum(a, dtype="f4")
assert t > 0