Skip to content

Commit

Permalink
Adds element-wise functions angle and reciprocal (#1474)
Browse files Browse the repository at this point in the history
* Implements elementwise reciprocal

* Fixes typo in kernels/elementwise_functions/proj.hpp

* Implements elementwise angle

* UnaryElementwiseFunc class now takes an acceptance function

This change was made to mirror promotion behavior of divide in reciprocal

Adds getter method for the acceptance function

Adds tests for reciprocal

* Small bugfix in _zero_like

_zero_like did not have logic accounting for 0D arrays, so `x.imag` failed for 0D x

* _zero_like now allocates using the same sycl_queue

This prevents unexpected behavior when calling `imag`
i.e., for x with a real-valued data type
`dpctl.tensor.atan2(x.imag, x.real)` would not work prior to this change

* Fixes bugs in `real` and `imag` properties

The logic in these properties did not work for float16 data types, returning None instead of `self` or an array of zeros

* Adds tests for angle

* Adds tests for fixes to `real`/`imag` properties

* Adds test that `real`, `imag` use the same queue

* Correction to rsqrt docstring

* Change acceptance function names per feedback

`_acceptance_fn_default1` and `_acceptance_fn_default2` are now
`_acceptance_fn_default_unary` and `_acceptance_fn_default_binary`
  • Loading branch information
ndgrigorian authored Nov 30, 2023
1 parent 5ec9fd5 commit b3ba5ac
Show file tree
Hide file tree
Showing 18 changed files with 1,089 additions and 19 deletions.
2 changes: 2 additions & 0 deletions dpctl/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ set(_elementwise_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acos.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acosh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/add.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/angle.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asin.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asinh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan.cpp
Expand Down Expand Up @@ -87,6 +88,7 @@ set(_elementwise_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/pow.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/proj.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/real.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/reciprocal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/remainder.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/round.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/rsqrt.cpp
Expand Down
4 changes: 4 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
acos,
acosh,
add,
angle,
asin,
asinh,
atan,
Expand Down Expand Up @@ -153,6 +154,7 @@
pow,
proj,
real,
reciprocal,
remainder,
round,
rsqrt,
Expand Down Expand Up @@ -342,4 +344,6 @@
"var",
"__array_api_version__",
"__array_namespace_info__",
"reciprocal",
"angle",
]
52 changes: 48 additions & 4 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@

from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
from ._type_utils import (
_acceptance_fn_default,
_acceptance_fn_default_binary,
_acceptance_fn_default_unary,
_all_data_types,
_find_buf_dtype,
_find_buf_dtype2,
Expand Down Expand Up @@ -62,17 +63,39 @@ class UnaryElementwiseFunc:
computational tasks complete execution, while the second event
corresponds to computational tasks associated with function
evaluation.
acceptance_fn (callable, optional):
Function to influence type promotion behavior of this unary
function. The function takes 4 arguments:
arg_dtype - Data type of the first argument
buf_dtype - Data type the argument would be cast to
res_dtype - Data type of the output array with function values
sycl_dev - The :class:`dpctl.SyclDevice` where the function
evaluation is carried out.
The function is invoked when the argument of the unary function
requires casting, e.g. the argument of `dpctl.tensor.log` is an
array with integral data type.
docs (str):
Documentation string for the unary function.
"""

def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
def __init__(
self,
name,
result_type_resolver_fn,
unary_dp_impl_fn,
docs,
acceptance_fn=None,
):
self.__name__ = "UnaryElementwiseFunc"
self.name_ = name
self.result_type_resolver_fn_ = result_type_resolver_fn
self.types_ = None
self.unary_fn_ = unary_dp_impl_fn
self.__doc__ = docs
if callable(acceptance_fn):
self.acceptance_fn_ = acceptance_fn
else:
self.acceptance_fn_ = _acceptance_fn_default_unary

def __str__(self):
return f"<{self.__name__} '{self.name_}'>"
Expand All @@ -93,6 +116,24 @@ def get_type_result_resolver_function(self):
"""
return self.result_type_resolver_fn_

def get_type_promotion_path_acceptance_function(self):
"""Returns the acceptance function for this
elementwise binary function.
Acceptance function influences the type promotion
behavior of this unary function.
The function takes 4 arguments:
arg_dtype - Data type of the first argument
buf_dtype - Data type the argument would be cast to
res_dtype - Data type of the output array with function values
sycl_dev - The :class:`dpctl.SyclDevice` where the function
evaluation is carried out.
The function is invoked when the argument of the unary function
requires casting, e.g. the argument of `dpctl.tensor.log` is an
array with integral data type.
"""
return self.acceptance_fn_

@property
def types(self):
"""Returns information about types supported by
Expand Down Expand Up @@ -122,7 +163,10 @@ def __call__(self, x, out=None, order="K"):
if order not in ["C", "F", "K", "A"]:
order = "K"
buf_dt, res_dt = _find_buf_dtype(
x.dtype, self.result_type_resolver_fn_, x.sycl_device
x.dtype,
self.result_type_resolver_fn_,
x.sycl_device,
acceptance_fn=self.acceptance_fn_,
)
if res_dt is None:
raise TypeError(
Expand Down Expand Up @@ -482,7 +526,7 @@ def __init__(
if callable(acceptance_fn):
self.acceptance_fn_ = acceptance_fn
else:
self.acceptance_fn_ = _acceptance_fn_default
self.acceptance_fn_ = _acceptance_fn_default_binary

def __str__(self):
return f"<{self.__name__} '{self.name_}'>"
Expand Down
66 changes: 64 additions & 2 deletions dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import dpctl.tensor._tensor_elementwise_impl as ti

from ._elementwise_common import BinaryElementwiseFunc, UnaryElementwiseFunc
from ._type_utils import _acceptance_fn_divide
from ._type_utils import _acceptance_fn_divide, _acceptance_fn_reciprocal

# U01: ==== ABS (x)
_abs_docstring_ = """
Expand Down Expand Up @@ -1880,10 +1880,72 @@
Returns:
usm_narray:
An array containing the element-wise reciprocal square-root.
The data type of the returned array is determined by
The returned array has a floating-point data type determined by
the Type Promotion Rules.
"""

rsqrt = UnaryElementwiseFunc(
"rsqrt", ti._rsqrt_result_type, ti._rsqrt, _rsqrt_docstring_
)


# U42: ==== RECIPROCAL (x)
_reciprocal_docstring = """
reciprocal(x, out=None, order='K')
Computes the reciprocal of each element `x_i` for input array `x`.
Args:
x (usm_ndarray):
Input array, expected to have a real-valued floating-point data type.
out ({None, usm_ndarray}, optional):
Output array to populate.
Array have the correct shape and the expected data type.
order ("C","F","A","K", optional):
Memory layout of the newly output array, if parameter `out` is `None`.
Default: "K".
Returns:
usm_narray:
An array containing the element-wise reciprocals.
The returned array has a floating-point data type determined
by the Type Promotion Rules.
"""

reciprocal = UnaryElementwiseFunc(
"reciprocal",
ti._reciprocal_result_type,
ti._reciprocal,
_reciprocal_docstring,
acceptance_fn=_acceptance_fn_reciprocal,
)


# U43: ==== ANGLE (x)
_angle_docstring = """
angle(x, out=None, order='K')
Computes the phase angle (also called the argument) of each element `x_i` for
input array `x`.
Args:
x (usm_ndarray):
Input array, expected to have a complex-valued floating-point data type.
out ({None, usm_ndarray}, optional):
Output array to populate.
Array have the correct shape and the expected data type.
order ("C","F","A","K", optional):
Memory layout of the newly output array, if parameter `out` is `None`.
Default: "K".
Returns:
usm_narray:
An array containing the element-wise phase angles.
The returned array has a floating-point data type determined
by the Type Promotion Rules.
"""

angle = UnaryElementwiseFunc(
"angle",
ti._angle_result_type,
ti._angle,
_angle_docstring,
)
34 changes: 30 additions & 4 deletions dpctl/tensor/_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,27 @@ def _to_device_supported_dtype(dt, dev):
return dt


def _find_buf_dtype(arg_dtype, query_fn, sycl_dev):
def _acceptance_fn_default_unary(arg_dtype, ret_buf_dt, res_dt, sycl_dev):
return True


def _acceptance_fn_reciprocal(arg_dtype, buf_dt, res_dt, sycl_dev):
# if the kind of result is different from
# the kind of input, use the default data
# we use default dtype for the resulting kind.
# This guarantees alignment of reciprocal and
# divide output types.
if buf_dt.kind != arg_dtype.kind:
default_dt = _get_device_default_dtype(res_dt.kind, sycl_dev)
if res_dt == default_dt:
return True
else:
return False
else:
return True


def _find_buf_dtype(arg_dtype, query_fn, sycl_dev, acceptance_fn):
res_dt = query_fn(arg_dtype)
if res_dt:
return None, res_dt
Expand All @@ -144,7 +164,11 @@ def _find_buf_dtype(arg_dtype, query_fn, sycl_dev):
if _can_cast(arg_dtype, buf_dt, _fp16, _fp64):
res_dt = query_fn(buf_dt)
if res_dt:
return buf_dt, res_dt
acceptable = acceptance_fn(arg_dtype, buf_dt, res_dt, sycl_dev)
if acceptable:
return buf_dt, res_dt
else:
continue

return None, None

Expand All @@ -163,7 +187,7 @@ def _get_device_default_dtype(dt_kind, sycl_dev):
raise RuntimeError


def _acceptance_fn_default(
def _acceptance_fn_default_binary(
arg1_dtype, arg2_dtype, ret_buf1_dt, ret_buf2_dt, res_dt, sycl_dev
):
return True
Expand Down Expand Up @@ -230,6 +254,8 @@ def _find_buf_dtype2(arg1_dtype, arg2_dtype, query_fn, sycl_dev, acceptance_fn):
"_find_buf_dtype",
"_find_buf_dtype2",
"_to_device_supported_dtype",
"_acceptance_fn_default",
"_acceptance_fn_default_unary",
"_acceptance_fn_reciprocal",
"_acceptance_fn_default_binary",
"_acceptance_fn_divide",
]
15 changes: 9 additions & 6 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,8 @@ cdef class usm_ndarray:
""" Returns real component for arrays with complex data-types
and returns itself for all other data-types.
"""
if (self.typenum_ < UAR_CFLOAT):
# explicitly check for UAR_HALF, which is greater than UAR_CFLOAT
if (self.typenum_ < UAR_CFLOAT or self.typenum_ == UAR_HALF):
# elements are real
return self
if (self.typenum_ < UAR_TYPE_SENTINEL):
Expand All @@ -698,7 +699,8 @@ cdef class usm_ndarray:
""" Returns imaginary component for arrays with complex data-types
and returns zero array for all other data-types.
"""
if (self.typenum_ < UAR_CFLOAT):
# explicitly check for UAR_HALF, which is greater than UAR_CFLOAT
if (self.typenum_ < UAR_CFLOAT or self.typenum_ == UAR_HALF):
# elements are real
return _zero_like(self)
if (self.typenum_ < UAR_TYPE_SENTINEL):
Expand Down Expand Up @@ -1306,14 +1308,15 @@ cdef usm_ndarray _m_transpose(usm_ndarray ary):

cdef usm_ndarray _zero_like(usm_ndarray ary):
"""
Make C-contiguous array of zero elements with same shape
and type as ary.
Make C-contiguous array of zero elements with same shape,
type, device, and sycl_queue as ary.
"""
cdef dt = _make_typestr(ary.typenum_)
cdef usm_ndarray r = usm_ndarray(
_make_int_tuple(ary.nd_, ary.shape_),
_make_int_tuple(ary.nd_, ary.shape_) if ary.nd_ > 0 else tuple(),
dtype=dt,
buffer=ary.base_.get_usm_type()
buffer=ary.base_.get_usm_type(),
buffer_ctor_kwargs={"queue": ary.get_sycl_queue()},
)
r.base_.memset()
return r
Expand Down
Loading

0 comments on commit b3ba5ac

Please sign in to comment.