From 9253931f5b53884f04d7269062a39951546e1777 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 8 Nov 2023 02:27:44 -0800 Subject: [PATCH] Changes mean reduction to use output data type as sum accumulation type Mean in-place division now uses the real type for the denominator --- dpctl/tensor/_statistical_functions.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/dpctl/tensor/_statistical_functions.py b/dpctl/tensor/_statistical_functions.py index 6508079e98..54d748d2d2 100644 --- a/dpctl/tensor/_statistical_functions.py +++ b/dpctl/tensor/_statistical_functions.py @@ -22,8 +22,6 @@ import dpctl.tensor._tensor_impl as ti import dpctl.tensor._tensor_reductions_impl as tri -from ._reduction import _default_reduction_dtype - def _var_impl(x, axis, correction, keepdims): nd = x.ndim @@ -233,22 +231,25 @@ def mean(x, axis=None, keepdims=False): host_tasks_list.append(ht_e1) s_e.append(r_e) else: - tmp_dt = _default_reduction_dtype(inp_dt, q) tmp = dpt.empty( - res_shape, dtype=tmp_dt, usm_type=res_usm_type, sycl_queue=q + arr2.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q ) - ht_e_tmp, r_e = tri._sum_over_axis( - src=arr2, trailing_dims_to_reduce=sum_nd, dst=tmp, sycl_queue=q + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr2, dst=tmp, sycl_queue=q ) - host_tasks_list.append(ht_e_tmp) + host_tasks_list.append(ht_e_cpy) res = dpt.empty( res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q ) - ht_e1, c_e = ti._copy_usm_ndarray_into_usm_ndarray( - src=tmp, dst=res, sycl_queue=q, depends=[r_e] + ht_e_red, r_e = tri._sum_over_axis( + src=tmp, + trailing_dims_to_reduce=sum_nd, + dst=res, + sycl_queue=q, + depends=[cpy_e], ) - host_tasks_list.append(ht_e1) - s_e.append(c_e) + host_tasks_list.append(ht_e_red) + s_e.append(r_e) if keepdims: res_shape = res_shape + (1,) * sum_nd @@ -257,8 +258,9 @@ def mean(x, axis=None, keepdims=False): res_shape = res.shape # in-place divide + den_dt = dpt.finfo(res_dt).dtype if res_dt.kind == "c" else res_dt nelems_arr = dpt.asarray( - nelems, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q + nelems, dtype=den_dt, usm_type=res_usm_type, sycl_queue=q ) if nelems_arr.shape != res_shape: nelems_arr = dpt.broadcast_to(nelems_arr, res_shape)