Closes #3641_remove_translate_np_dtype (#3642)
Co-authored-by: Amanda Potts <ajpotts@users.noreply.github.com>
ajpotts committed Aug 14, 2024
1 parent bc34f65 commit 689e779
Showing 5 changed files with 15 additions and 83 deletions.
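This commit removes the `check_np_dtype` and `translate_np_dtype` helpers from `arkouda/dtypes.py`, deletes their unit tests, and updates the call sites in `array_view.py`, `pdarrayclass.py`, and `strings.py` to compare a pdarray's `dtype` directly against the kind names. A minimal sketch of the before/after call-site idiom, assuming `key` behaves like the pdarray indices in the diffs below (the helper function name here is illustrative, not part of the commit):

```python
import numpy as np

# Before (removed): map the dtype to a (kind, itemsize) pair, then test the kind.
#     kind, _ = translate_np_dtype(key.dtype)   # e.g. ("int", 8) for int64
#     if kind not in ("int", "uint", "bool"):
#         raise TypeError(f"unsupported pdarray index type {key.dtype}")

# After: test the dtype itself against the kind names, relying on numpy's
# dtype/string comparison.
def validate_index_dtype(key) -> None:
    if key.dtype not in ("int", "uint", "bool"):
        raise TypeError(f"unsupported pdarray index type {key.dtype}")

validate_index_dtype(np.arange(3))      # integer index dtype: accepted
# validate_index_dtype(np.arange(3.0))  # float index dtype: raises TypeError
```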
PROTO_tests/tests/dtypes_test.py (20 changes: 0 additions & 20 deletions)
@@ -23,26 +23,6 @@


 class TestDTypes:
-    @pytest.mark.parametrize("dtype", SUPPORTED_NP_DTYPES)
-    def test_check_np_dtype(self, dtype):
-        dtypes.check_np_dtype(np.dtype(dtype))
-
-    @pytest.mark.parametrize("dtype", ["np.str"])
-    def test_check_np_dtype_errors(self, dtype):
-        with pytest.raises(TypeError):
-            dtypes.check_np_dtype(dtype)
-
-    def test_translate_np_dtype(self):
-        for b in [np.bool_, bool]:
-            assert ("bool", 1) == dtypes.translate_np_dtype(np.dtype(b))
-
-        for s in [np.str_, str]:
-            assert ("str", 0) == dtypes.translate_np_dtype(np.dtype(s))
-
-        assert ("int", 8) == dtypes.translate_np_dtype(np.dtype(np.int64))
-        assert ("uint", 8) == dtypes.translate_np_dtype(np.dtype(np.uint64))
-        assert ("float", 8) == dtypes.translate_np_dtype(np.dtype(np.float64))
-        assert ("uint", 1) == dtypes.translate_np_dtype(np.dtype(np.uint8))
-
     def test_resolve_scalar_dtype(self):
         for b in True, False:
arkouda/array_view.py (21 changes: 9 additions & 12 deletions)
@@ -6,7 +6,7 @@
 import numpy as np

 from arkouda.client import generic_msg
-from arkouda.dtypes import resolve_scalar_dtype, translate_np_dtype
+from arkouda.dtypes import resolve_scalar_dtype
 from arkouda.numeric import cast as akcast
 from arkouda.numeric import cumprod, where
 from arkouda.pdarrayclass import create_pdarray, parse_single_value, pdarray
@@ -125,10 +125,9 @@ def __getitem__(self, key):
         except (RuntimeError, TypeError, ValueError, DeprecationWarning):
             pass
         if isinstance(key, pdarray):
-            kind, _ = translate_np_dtype(key.dtype)
-            if kind not in ("int", "uint", "bool"):
+            if key.dtype not in ("int", "uint", "bool"):
                 raise TypeError(f"unsupported pdarray index type {key.dtype}")
-            if kind == "bool":
+            if key.dtype == "bool":
                 if key.all():
                     # every dimension is True, so return this arrayview with shape = [1, self.shape]
                     return self.base.reshape(
@@ -141,7 +140,7 @@
                         concatenate([zeros(1, dtype=self.dtype), self.shape]), order=self.order.name
                     )
             # Interpret negative key as offset from end of array
-            key = where(key < 0, akcast(key + self.shape, kind), key)
+            key = where(key < 0, akcast(key + self.shape, key.dtype), key)
             # Capture the indices which are still out of bounds
             out_of_bounds = (key < 0) | (self.shape <= key)
             if out_of_bounds.any():
@@ -191,10 +190,9 @@ def __getitem__(self, key):
                 elif isinstance(x, pdarray) or isinstance(x, list):
                     raise TypeError(f"Advanced indexing is not yet supported {x} ({type(x)})")
                     # x = array(x)
-                    # kind, _ = translate_np_dtype(x.dtype)
-                    # if kind not in ("bool", "int"):
+                    # if key.dtype not in ("bool", "int"):
                     #     raise TypeError("unsupported pdarray index type {}".format(x.dtype))
-                    # if kind == "bool" and dim != x.size:
+                    # if key.dtype == "bool" and dim != x.size:
                     #     raise ValueError("size mismatch {} {}".format(dim, x.size))
                     # indices.append('pdarray')
                     # indices.append(x.name)
@@ -243,17 +241,16 @@ def __setitem__(self, key, value):
         except (RuntimeError, TypeError, ValueError, DeprecationWarning):
             pass
         if isinstance(key, pdarray):
-            kind, _ = translate_np_dtype(key.dtype)
-            if kind not in ("int", "uint", "bool"):
+            if key.dtype not in ("int", "uint", "bool"):
                 raise TypeError(f"unsupported pdarray index type {key.dtype}")
-            if kind == "bool":
+            if key.dtype == "bool":
                 if key.all():
                     # every dimension is True, so fill arrayview with value
                     # if any dimension is False, we don't update anything
                     self.base.fill(value)
             else:
                 # Interpret negative key as offset from end of array
-                key = where(key < 0, akcast(key + self.shape, kind), key)
+                key = where(key < 0, akcast(key + self.shape, key.dtype), key)
                 # Capture the indices which are still out of bounds
                 out_of_bounds = (key < 0) | (self.shape <= key)
                 if out_of_bounds.any():
arkouda/dtypes.py (42 changes: 1 addition & 41 deletions)
@@ -3,10 +3,9 @@
 import builtins
 import sys
 from enum import Enum
-from typing import Tuple, Union, cast
+from typing import Union, cast

 import numpy as np
-from typeguard import typechecked

 __all__ = [
     "DTypes",
@@ -30,8 +29,6 @@
     "bigint",
     "intTypes",
     "bitType",
-    "check_np_dtype",
-    "translate_np_dtype",
     "resolve_scalar_dtype",
     "ARKOUDA_SUPPORTED_DTYPES",
     "bool_scalars",
@@ -294,43 +291,6 @@ def isSupportedNumber(num):
     return isinstance(num, ARKOUDA_SUPPORTED_NUMBERS)


-@typechecked
-def check_np_dtype(dt: Union[np.dtype, "bigint"]) -> None:
-    """
-    Assert that numpy dtype dt is one of the dtypes supported
-    by arkouda, otherwise raise TypeError.
-
-    Raises
-    ------
-    TypeError
-        Raised if the dtype is not in supported dtypes or if
-        dt is not a np.dtype
-    """
-
-    if dtype(dt).name not in DTypes:
-        raise TypeError(f"Unsupported type: {dt}")
-
-
-@typechecked
-def translate_np_dtype(dt) -> Tuple[builtins.str, int]:
-    """
-    Split numpy dtype dt into its kind and byte size, raising
-    TypeError for unsupported dtypes.
-
-    Raises
-    ------
-    TypeError
-        Raised if the dtype is not in supported dtypes or if
-        dt is not a np.dtype
-    """
-    # Assert that dt is one of the arkouda supported dtypes
-    dt = dtype(dt)
-    check_np_dtype(dt)
-    trans = {"i": "int", "f": "float", "b": "bool", "u": "uint", "U": "str", "c": "complex"}
-    kind = trans[dt.kind]
-    return kind, dt.itemsize
-
-
 def resolve_scalar_dtype(val: object) -> str:
     """
     Try to infer what dtype arkouda_server should treat val as.
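For reference, the deleted `translate_np_dtype` was a thin wrapper over numpy dtype introspection. A standalone sketch of the equivalent kind/byte-size split, mirroring the removed implementation above but with the `check_np_dtype` membership check omitted and the function name changed:

```python
import numpy as np

def kind_and_itemsize(dt) -> tuple[str, int]:
    # Same mapping the removed translate_np_dtype performed, without the
    # DTypes membership check that check_np_dtype used to provide.
    dt = np.dtype(dt)
    trans = {"i": "int", "f": "float", "b": "bool", "u": "uint", "U": "str", "c": "complex"}
    return trans[dt.kind], dt.itemsize

# These mirror the expectations in the removed test_translate_np_dtype test.
assert kind_and_itemsize(np.int64) == ("int", 8)
assert kind_and_itemsize(np.uint8) == ("uint", 1)
assert kind_and_itemsize(np.bool_) == ("bool", 1)
assert kind_and_itemsize(str) == ("str", 0)
```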
arkouda/pdarrayclass.py (9 changes: 3 additions & 6 deletions)
@@ -26,7 +26,6 @@
     resolve_scalar_dtype,
 )
 from arkouda.dtypes import str_ as akstr_
-from arkouda.dtypes import translate_np_dtype
 from arkouda.dtypes import uint64 as akuint64
 from arkouda.infoclass import information, pretty_print_information
 from arkouda.logger import getArkoudaLogger
@@ -221,8 +220,7 @@ def _parse_index_tuple(key, shape):
             slices.append((k, k + 1, 1))
         elif isinstance(k, pdarray):
             pdarray_axes.append(dim)
-            kind, _ = translate_np_dtype(k.dtype)
-            if kind not in ("bool", "int", "uint"):
+            if k.dtype not in ("bool", "int", "uint"):
                 raise TypeError(f"unsupported pdarray index type {k.dtype}")
             # select all indices (needed for mixed slice+pdarray indexing)
             slices.append((0, shape[dim], 1))
@@ -987,10 +985,9 @@ def __getitem__(self, key):
             return ret_array

         if isinstance(key, pdarray) and self.ndim == 1:
-            kind, _ = translate_np_dtype(key.dtype)
-            if kind not in ("bool", "int", "uint"):
+            if key.dtype not in ("bool", "int", "uint"):
                 raise TypeError(f"unsupported pdarray index type {key.dtype}")
-            if kind == "bool" and self.size != key.size:
+            if key.dtype == "bool" and self.size != key.size:
                 raise ValueError(f"size mismatch {self.size} {key.size}")
             repMsg = generic_msg(
                 cmd="[pdarray]",
arkouda/strings.py (6 changes: 2 additions & 4 deletions)
@@ -17,7 +17,6 @@
     resolve_scalar_dtype,
     str_,
     str_scalars,
-    translate_np_dtype,
 )
 from arkouda.infoclass import information, list_symbol_table
 from arkouda.logger import getArkoudaLogger
@@ -320,10 +319,9 @@ def __getitem__(self, key):
             )
             return Strings.from_return_msg(repMsg)
         elif isinstance(key, pdarray):
-            kind, _ = translate_np_dtype(key.dtype)
-            if kind not in ("bool", "int", "uint"):
+            if key.dtype not in ("bool", "int", "uint"):
                 raise TypeError(f"unsupported pdarray index type {key.dtype}")
-            if kind == "bool" and self.size != key.size:
+            if key.dtype == "bool" and self.size != key.size:
                 raise ValueError(f"size mismatch {self.size} {key.size}")
             repMsg = generic_msg(
                 cmd="segmentedIndex",
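The index-validation behavior these call sites keep enforcing, as a usage sketch (assumes a running arkouda server reachable with the default connection settings; purely illustrative):

```python
import arkouda as ak

ak.connect()  # assumes arkouda_server is running locally on the default port

a = ak.arange(10)              # int64 pdarray
mask = a % 2 == 0              # bool pdarray, same size as a
print(a[mask])                 # bool index with matching size: allowed
print(a[ak.array([1, 3, 5])])  # int index: allowed

# A pdarray index with any other dtype (e.g. float64) hits the
# TypeError("unsupported pdarray index type ...") branch, and a bool mask
# whose size differs from a.size hits the ValueError("size mismatch ...") branch.
ak.disconnect()
```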
