Skip to content

Commit

Permalink
Changes mean reduction to use output data type as sum accumulation type
Browse files Browse the repository at this point in the history
Mean in-place division now uses the real type for the denominator
  • Loading branch information
ndgrigorian committed Nov 8, 2023
1 parent 69fdaa5 commit 19ffc5f
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions dpctl/tensor/_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import dpctl.tensor._tensor_impl as ti
import dpctl.tensor._tensor_reductions_impl as tri

from ._reduction import _default_reduction_dtype


def _var_impl(x, axis, correction, keepdims):
nd = x.ndim
Expand Down Expand Up @@ -233,22 +231,25 @@ def mean(x, axis=None, keepdims=False):
host_tasks_list.append(ht_e1)
s_e.append(r_e)
else:
tmp_dt = _default_reduction_dtype(inp_dt, q)
tmp = dpt.empty(
res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q
arr2.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
)
ht_e_tmp, r_e = tri._sum_over_axis(
src=arr2, trailing_dims_to_reduce=sum_nd, dst=tmp, sycl_queue=q
ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
src=arr2, dst=tmp, sycl_queue=q
)
host_tasks_list.append(ht_e_tmp)
host_tasks_list.append(ht_e_cpy)
res = dpt.empty(
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
)
ht_e1, c_e = ti._copy_usm_ndarray_into_usm_ndarray(
src=tmp, dst=res, sycl_queue=q, depends=[r_e]
ht_e_red, r_e = tri._sum_over_axis(
src=tmp,
trailing_dims_to_reduce=sum_nd,
dst=res,
sycl_queue=q,
depends=[cpy_e],
)
host_tasks_list.append(ht_e1)
s_e.append(c_e)
host_tasks_list.append(ht_e_red)
s_e.append(r_e)

if keepdims:
res_shape = res_shape + (1,) * sum_nd
Expand All @@ -257,8 +258,9 @@ def mean(x, axis=None, keepdims=False):

res_shape = res.shape
# in-place divide
den_dt = dpt.finfo(res_dt).dtype if res_dt.kind == "c" else res_dt
nelems_arr = dpt.asarray(
nelems, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q
nelems, dtype=den_dt, usm_type=res_usm_type, sycl_queue=q
)
if nelems_arr.shape != res_shape:
nelems_arr = dpt.broadcast_to(nelems_arr, res_shape)
Expand Down

0 comments on commit 19ffc5f

Please sign in to comment.