From 689e7794038003c160ed039ea9f5751f1258097a Mon Sep 17 00:00:00 2001 From: ajpotts Date: Wed, 14 Aug 2024 10:49:32 -0400 Subject: [PATCH] Closes #3641_remove_translate_np_dtype (#3642) Co-authored-by: Amanda Potts --- PROTO_tests/tests/dtypes_test.py | 20 --------------- arkouda/array_view.py | 21 +++++++--------- arkouda/dtypes.py | 42 +------------------------------- arkouda/pdarrayclass.py | 9 +++---- arkouda/strings.py | 6 ++--- 5 files changed, 15 insertions(+), 83 deletions(-) diff --git a/PROTO_tests/tests/dtypes_test.py b/PROTO_tests/tests/dtypes_test.py index 8ed5e5e2f6..5a153b52be 100644 --- a/PROTO_tests/tests/dtypes_test.py +++ b/PROTO_tests/tests/dtypes_test.py @@ -23,26 +23,6 @@ class TestDTypes: - @pytest.mark.parametrize("dtype", SUPPORTED_NP_DTYPES) - def test_check_np_dtype(self, dtype): - dtypes.check_np_dtype(np.dtype(dtype)) - - @pytest.mark.parametrize("dtype", ["np.str"]) - def test_check_np_dtype_errors(self, dtype): - with pytest.raises(TypeError): - dtypes.check_np_dtype(dtype) - - def test_translate_np_dtype(self): - for b in [np.bool_, bool]: - assert ("bool", 1) == dtypes.translate_np_dtype(np.dtype(b)) - - for s in [np.str_, str]: - assert ("str", 0) == dtypes.translate_np_dtype(np.dtype(s)) - - assert ("int", 8) == dtypes.translate_np_dtype(np.dtype(np.int64)) - assert ("uint", 8) == dtypes.translate_np_dtype(np.dtype(np.uint64)) - assert ("float", 8) == dtypes.translate_np_dtype(np.dtype(np.float64)) - assert ("uint", 1) == dtypes.translate_np_dtype(np.dtype(np.uint8)) def test_resolve_scalar_dtype(self): for b in True, False: diff --git a/arkouda/array_view.py b/arkouda/array_view.py index 5a774f3ea1..0f6bb8d3de 100644 --- a/arkouda/array_view.py +++ b/arkouda/array_view.py @@ -6,7 +6,7 @@ import numpy as np from arkouda.client import generic_msg -from arkouda.dtypes import resolve_scalar_dtype, translate_np_dtype +from arkouda.dtypes import resolve_scalar_dtype from arkouda.numeric import cast as akcast from arkouda.numeric import cumprod, where from arkouda.pdarrayclass import create_pdarray, parse_single_value, pdarray @@ -125,10 +125,9 @@ def __getitem__(self, key): except (RuntimeError, TypeError, ValueError, DeprecationWarning): pass if isinstance(key, pdarray): - kind, _ = translate_np_dtype(key.dtype) - if kind not in ("int", "uint", "bool"): + if key.dtype not in ("int", "uint", "bool"): raise TypeError(f"unsupported pdarray index type {key.dtype}") - if kind == "bool": + if key.dtype == "bool": if key.all(): # every dimension is True, so return this arrayview with shape = [1, self.shape] return self.base.reshape( @@ -141,7 +140,7 @@ def __getitem__(self, key): concatenate([zeros(1, dtype=self.dtype), self.shape]), order=self.order.name ) # Interpret negative key as offset from end of array - key = where(key < 0, akcast(key + self.shape, kind), key) + key = where(key < 0, akcast(key + self.shape, key.dtype), key) # Capture the indices which are still out of bounds out_of_bounds = (key < 0) | (self.shape <= key) if out_of_bounds.any(): @@ -191,10 +190,9 @@ def __getitem__(self, key): elif isinstance(x, pdarray) or isinstance(x, list): raise TypeError(f"Advanced indexing is not yet supported {x} ({type(x)})") # x = array(x) - # kind, _ = translate_np_dtype(x.dtype) - # if kind not in ("bool", "int"): + # if key.dtype not in ("bool", "int"): # raise TypeError("unsupported pdarray index type {}".format(x.dtype)) - # if kind == "bool" and dim != x.size: + # if key.dtype == "bool" and dim != x.size: # raise ValueError("size mismatch {} {}".format(dim, x.size)) # indices.append('pdarray') # indices.append(x.name) @@ -243,17 +241,16 @@ def __setitem__(self, key, value): except (RuntimeError, TypeError, ValueError, DeprecationWarning): pass if isinstance(key, pdarray): - kind, _ = translate_np_dtype(key.dtype) - if kind not in ("int", "uint", "bool"): + if key.dtype not in ("int", "uint", "bool"): raise TypeError(f"unsupported pdarray index type {key.dtype}") - if kind == "bool": + if key.dtype == "bool": if key.all(): # every dimension is True, so fill arrayview with value # if any dimension is False, we don't update anything self.base.fill(value) else: # Interpret negative key as offset from end of array - key = where(key < 0, akcast(key + self.shape, kind), key) + key = where(key < 0, akcast(key + self.shape, key.dtype), key) # Capture the indices which are still out of bounds out_of_bounds = (key < 0) | (self.shape <= key) if out_of_bounds.any(): diff --git a/arkouda/dtypes.py b/arkouda/dtypes.py index 1d949dc682..e3b1ac3a4a 100644 --- a/arkouda/dtypes.py +++ b/arkouda/dtypes.py @@ -3,10 +3,9 @@ import builtins import sys from enum import Enum -from typing import Tuple, Union, cast +from typing import Union, cast import numpy as np -from typeguard import typechecked __all__ = [ "DTypes", @@ -30,8 +29,6 @@ "bigint", "intTypes", "bitType", - "check_np_dtype", - "translate_np_dtype", "resolve_scalar_dtype", "ARKOUDA_SUPPORTED_DTYPES", "bool_scalars", @@ -294,43 +291,6 @@ def isSupportedNumber(num): return isinstance(num, ARKOUDA_SUPPORTED_NUMBERS) -@typechecked -def check_np_dtype(dt: Union[np.dtype, "bigint"]) -> None: - """ - Assert that numpy dtype dt is one of the dtypes supported - by arkouda, otherwise raise TypeError. - - Raises - ------ - TypeError - Raised if the dtype is not in supported dtypes or if - dt is not a np.dtype - """ - - if dtype(dt).name not in DTypes: - raise TypeError(f"Unsupported type: {dt}") - - -@typechecked -def translate_np_dtype(dt) -> Tuple[builtins.str, int]: - """ - Split numpy dtype dt into its kind and byte size, raising - TypeError for unsupported dtypes. - - Raises - ------ - TypeError - Raised if the dtype is not in supported dtypes or if - dt is not a np.dtype - """ - # Assert that dt is one of the arkouda supported dtypes - dt = dtype(dt) - check_np_dtype(dt) - trans = {"i": "int", "f": "float", "b": "bool", "u": "uint", "U": "str", "c": "complex"} - kind = trans[dt.kind] - return kind, dt.itemsize - - def resolve_scalar_dtype(val: object) -> str: """ Try to infer what dtype arkouda_server should treat val as. diff --git a/arkouda/pdarrayclass.py b/arkouda/pdarrayclass.py index 0f48086834..ca04777a0c 100644 --- a/arkouda/pdarrayclass.py +++ b/arkouda/pdarrayclass.py @@ -26,7 +26,6 @@ resolve_scalar_dtype, ) from arkouda.dtypes import str_ as akstr_ -from arkouda.dtypes import translate_np_dtype from arkouda.dtypes import uint64 as akuint64 from arkouda.infoclass import information, pretty_print_information from arkouda.logger import getArkoudaLogger @@ -221,8 +220,7 @@ def _parse_index_tuple(key, shape): slices.append((k, k + 1, 1)) elif isinstance(k, pdarray): pdarray_axes.append(dim) - kind, _ = translate_np_dtype(k.dtype) - if kind not in ("bool", "int", "uint"): + if k.dtype not in ("bool", "int", "uint"): raise TypeError(f"unsupported pdarray index type {k.dtype}") # select all indices (needed for mixed slice+pdarray indexing) slices.append((0, shape[dim], 1)) @@ -987,10 +985,9 @@ def __getitem__(self, key): return ret_array if isinstance(key, pdarray) and self.ndim == 1: - kind, _ = translate_np_dtype(key.dtype) - if kind not in ("bool", "int", "uint"): + if key.dtype not in ("bool", "int", "uint"): raise TypeError(f"unsupported pdarray index type {key.dtype}") - if kind == "bool" and self.size != key.size: + if key.dtype == "bool" and self.size != key.size: raise ValueError(f"size mismatch {self.size} {key.size}") repMsg = generic_msg( cmd="[pdarray]", diff --git a/arkouda/strings.py b/arkouda/strings.py index ce7ae44d29..8f6a5c4004 100644 --- a/arkouda/strings.py +++ b/arkouda/strings.py @@ -17,7 +17,6 @@ resolve_scalar_dtype, str_, str_scalars, - translate_np_dtype, ) from arkouda.infoclass import information, list_symbol_table from arkouda.logger import getArkoudaLogger @@ -320,10 +319,9 @@ def __getitem__(self, key): ) return Strings.from_return_msg(repMsg) elif isinstance(key, pdarray): - kind, _ = translate_np_dtype(key.dtype) - if kind not in ("bool", "int", "uint"): + if key.dtype not in ("bool", "int", "uint"): raise TypeError(f"unsupported pdarray index type {key.dtype}") - if kind == "bool" and self.size != key.size: + if key.dtype == "bool" and self.size != key.size: raise ValueError(f"size mismatch {self.size} {key.size}") repMsg = generic_msg( cmd="segmentedIndex",