From 70953582729c033099825ba5e66f25f0d12dee09 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 23 Oct 2024 16:01:40 -0700 Subject: [PATCH] Factor out checks for fill value scalar type into `_validate_fill_value` function --- dpctl/tensor/_ctors.py | 37 +++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/dpctl/tensor/_ctors.py b/dpctl/tensor/_ctors.py index 1e6573a792..37236cad6b 100644 --- a/dpctl/tensor/_ctors.py +++ b/dpctl/tensor/_ctors.py @@ -1038,6 +1038,19 @@ def _cast_fill_val(fill_val, dt): return fill_val +def _validate_fill_value(fill_val): + """ + Validates that `fill_val` is a numeric or boolean scalar. + """ + # TODO: verify if `np.True_` and `np.False_` should be instances of + # Number in NumPy, like other NumPy scalars and like Python bools + # check for `np.bool_` separately as NumPy<2 has no `np.bool` + if not isinstance(fill_val, Number) and not isinstance(fill_val, np.bool_): + raise TypeError( + f"array cannot be filled with scalar of type {type(fill_val)}" + ) + + def full( shape, fill_value, @@ -1111,16 +1124,8 @@ def full( sycl_queue=sycl_queue, ) return dpt.copy(dpt.broadcast_to(X, shape), order=order) - # TODO: verify if `np.True_` and `np.False_` should be instances of - # Number in NumPy, like other NumPy scalars and like Python bools - # check for `np.bool_` separately as NumPy<2 has no `np.bool` - elif not isinstance(fill_value, Number) and not isinstance( - fill_value, np.bool_ - ): - raise TypeError( - "`full` array cannot be constructed with value of type " - f"{type(fill_value)}" - ) + else: + _validate_fill_value(fill_value) sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device) usm_type = usm_type if usm_type is not None else "device" @@ -1491,16 +1496,8 @@ def full_like( ) _manager.add_event_pair(hev, copy_ev) return res - # TODO: verify if `np.True_` and `np.False_` should be instances of - # Number in NumPy, like other NumPy scalars and like Python bools - # check for `np.bool_` separately as NumPy<2 has no `np.bool` - elif not isinstance(fill_value, Number) and not isinstance( - fill_value, np.bool_ - ): - raise TypeError( - "`full` array cannot be constructed with value of type " - f"{type(fill_value)}" - ) + else: + _validate_fill_value(fill_value) dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value)) res = _empty_like_orderK(x, dtype, usm_type, sycl_queue)