Skip to content

Commit

Permalink
Merge pull request #1878 from IntelPython/improve-full-error-for-inva…
Browse files Browse the repository at this point in the history
…lid-scalar-type

Improve `dpctl.tensor.full` error for invalid `fill_value`
  • Loading branch information
ndgrigorian authored Oct 24, 2024
2 parents 286afae + 7095358 commit 9b83bef
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
* Fix for `tensor.result_type` when all inputs are Python built-in scalars [gh-1877](https://github.com/IntelPython/dpctl/pull/1877)

* Improved error in constructors `tensor.full` and `tensor.full_like` when provided a non-numeric fill value [gh-1878](https://github.com/IntelPython/dpctl/pull/1878)

### Maintenance

* Update black version used in Python code style workflow [gh-1828](https://github.com/IntelPython/dpctl/pull/1828)
Expand Down
18 changes: 18 additions & 0 deletions dpctl/tensor/_ctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

import operator
from numbers import Number

import numpy as np

Expand Down Expand Up @@ -1037,6 +1038,19 @@ def _cast_fill_val(fill_val, dt):
return fill_val


def _validate_fill_value(fill_val):
"""
Validates that `fill_val` is a numeric or boolean scalar.
"""
# TODO: verify if `np.True_` and `np.False_` should be instances of
# Number in NumPy, like other NumPy scalars and like Python bools
# check for `np.bool_` separately as NumPy<2 has no `np.bool`
if not isinstance(fill_val, Number) and not isinstance(fill_val, np.bool_):
raise TypeError(
f"array cannot be filled with scalar of type {type(fill_val)}"
)


def full(
shape,
fill_value,
Expand Down Expand Up @@ -1110,6 +1124,8 @@ def full(
sycl_queue=sycl_queue,
)
return dpt.copy(dpt.broadcast_to(X, shape), order=order)
else:
_validate_fill_value(fill_value)

sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
usm_type = usm_type if usm_type is not None else "device"
Expand Down Expand Up @@ -1480,6 +1496,8 @@ def full_like(
)
_manager.add_event_pair(hev, copy_ev)
return res
else:
_validate_fill_value(fill_value)

dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
res = _empty_like_orderK(x, dtype, usm_type, sycl_queue)
Expand Down
11 changes: 11 additions & 0 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2621,3 +2621,14 @@ def test_setitem_from_numpy_contig():

expected = dpt.reshape(dpt.arange(-10, 10, dtype=fp_dt), (4, 5))
assert dpt.all(dpt.flip(Xdpt, axis=-1) == expected)


def test_full_functions_raise_type_error():
get_queue_or_skip()

with pytest.raises(TypeError):
dpt.full(1, "0")

x = dpt.ones(1, dtype="i4")
with pytest.raises(TypeError):
dpt.full_like(x, "0")

0 comments on commit 9b83bef

Please sign in to comment.