Skip to content

Commit

Permalink
Factor out checks for fill value scalar type into `_validate_fill_val…
Browse files Browse the repository at this point in the history
…ue` function
  • Loading branch information
ndgrigorian committed Oct 23, 2024
1 parent 320e5e4 commit 7095358
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions dpctl/tensor/_ctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,19 @@ def _cast_fill_val(fill_val, dt):
return fill_val


def _validate_fill_value(fill_val):
"""
Validates that `fill_val` is a numeric or boolean scalar.
"""
# TODO: verify if `np.True_` and `np.False_` should be instances of
# Number in NumPy, like other NumPy scalars and like Python bools
# check for `np.bool_` separately as NumPy<2 has no `np.bool`
if not isinstance(fill_val, Number) and not isinstance(fill_val, np.bool_):
raise TypeError(
f"array cannot be filled with scalar of type {type(fill_val)}"
)


def full(
shape,
fill_value,
Expand Down Expand Up @@ -1111,16 +1124,8 @@ def full(
sycl_queue=sycl_queue,
)
return dpt.copy(dpt.broadcast_to(X, shape), order=order)
# TODO: verify if `np.True_` and `np.False_` should be instances of
# Number in NumPy, like other NumPy scalars and like Python bools
# check for `np.bool_` separately as NumPy<2 has no `np.bool`
elif not isinstance(fill_value, Number) and not isinstance(
fill_value, np.bool_
):
raise TypeError(
"`full` array cannot be constructed with value of type "
f"{type(fill_value)}"
)
else:
_validate_fill_value(fill_value)

sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
usm_type = usm_type if usm_type is not None else "device"
Expand Down Expand Up @@ -1491,16 +1496,8 @@ def full_like(
)
_manager.add_event_pair(hev, copy_ev)
return res
# TODO: verify if `np.True_` and `np.False_` should be instances of
# Number in NumPy, like other NumPy scalars and like Python bools
# check for `np.bool_` separately as NumPy<2 has no `np.bool`
elif not isinstance(fill_value, Number) and not isinstance(
fill_value, np.bool_
):
raise TypeError(
"`full` array cannot be constructed with value of type "
f"{type(fill_value)}"
)
else:
_validate_fill_value(fill_value)

dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
res = _empty_like_orderK(x, dtype, usm_type, sycl_queue)
Expand Down

0 comments on commit 7095358

Please sign in to comment.