Skip to content

Commit

Permalink
Part of #3708: array_api to call functions from arkouda.pdarray_creation
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Sep 10, 2024
1 parent 4db3469 commit ac22c11
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 140 deletions.
79 changes: 16 additions & 63 deletions arkouda/array_api/creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast

from arkouda.client import generic_msg
import numpy as np
from arkouda.pdarrayclass import create_pdarray, pdarray, _to_pdarray
from arkouda.pdarraycreation import scalar_array

from arkouda.numpy.dtypes import dtype as akdtype
from arkouda.numpy.dtypes import resolve_scalar_dtype
from arkouda.pdarrayclass import _to_pdarray, pdarray

if TYPE_CHECKING:
from ._typing import (
Expand All @@ -17,6 +16,7 @@
NestedSequence,
SupportsBufferProtocol,
)

import arkouda as ak


Expand Down Expand Up @@ -83,9 +83,7 @@ def asarray(
elif isinstance(obj, np.ndarray):
return Array._new(_to_pdarray(obj, dt=dtype))
else:
raise ValueError(
"asarray not implemented for 'NestedSequence' or 'SupportsBufferProtocol'"
)
raise ValueError("asarray not implemented for 'NestedSequence' or 'SupportsBufferProtocol'")


def arange(
Expand Down Expand Up @@ -155,9 +153,7 @@ def empty(
)


def empty_like(
x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
def empty_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
"""
Return a new array whose shape and dtype match the input array, without initializing entries.
"""
Expand Down Expand Up @@ -217,17 +213,8 @@ def eye(
if n_cols is not None:
cols = n_cols

repMsg = generic_msg(
cmd="eye",
args={
"dtype": np.dtype(dtype).name,
"rows": n_rows,
"cols": cols,
"diag": k,
},
)

return Array._new(create_pdarray(repMsg))
from arkouda import dtype as akdtype
return Array._new(ak.eye(rows=n_rows, cols=cols, diag=k, dt=akdtype(dtype)))


def from_dlpack(x: object, /) -> Array:
Expand Down Expand Up @@ -312,9 +299,7 @@ def ones(
return a


def ones_like(
x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
def ones_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
"""
Return a new array whose shape and dtype match the input array, filled with ones.
"""
Expand All @@ -328,33 +313,18 @@ def tril(x: Array, /, *, k: int = 0) -> Array:
"""
from .array_object import Array

repMsg = generic_msg(
cmd=f"tril{x._array.ndim}D",
args={
"array": x._array.name,
"diag": k,
},
)

return Array._new(create_pdarray(repMsg))
return Array._new(ak.tril(x._array, diag=k))


def triu(x: Array, /, *, k: int = 0) -> Array:
"""
Create a new array with the values from `x` above the `k`-th diagonal, and
all other elements zero.
"""
from .array_object import Array

repMsg = generic_msg(
cmd=f"triu{x._array.ndim}D",
args={
"array": x._array.name,
"diag": k,
},
)
from .array_object import Array

return Array._new(create_pdarray(repMsg))
return Array._new(ak.triu(x._array, k))


def zeros(
Expand All @@ -372,31 +342,14 @@ def zeros(
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")

if isinstance(shape, tuple):
if shape == ():
return Array._new(scalar_array(0, dtype=dtype))
else:
ndim = len(shape)
else:
if shape == 0:
return Array._new(scalar_array(0, dtype=dtype))
else:
ndim = 1

dtype = akdtype(dtype) # normalize dtype
dtype_name = cast(np.dtype, dtype).name
return_dtype = akdtype(dtype)
if dtype is None:
return_dtype = akdtype(ak.float64)

repMsg = generic_msg(
cmd=f"create<{dtype_name},{ndim}>",
args={"shape": shape},
)
return Array._new(ak.zeros(shape, return_dtype))

return Array._new(create_pdarray(repMsg))


def zeros_like(
x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
) -> Array:
def zeros_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
"""
Return a new array whose shape and dtype match the input array, filled with zeros.
"""
Expand Down
24 changes: 4 additions & 20 deletions arkouda/array_api/elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,7 @@ def bitwise_and(x1: Array, x2: Array, /) -> Array:
"""
Compute the element-wise bitwise AND of two arrays.
"""
if (
x1.dtype not in _integer_or_boolean_dtypes
or x2.dtype not in _integer_or_boolean_dtypes
):
if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
raise TypeError("Only integer or boolean dtypes are allowed in bitwise_and")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
Expand Down Expand Up @@ -141,10 +138,7 @@ def bitwise_or(x1: Array, x2: Array, /) -> Array:
"""
Compute the element-wise bitwise OR of two arrays.
"""
if (
x1.dtype not in _integer_or_boolean_dtypes
or x2.dtype not in _integer_or_boolean_dtypes
):
if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
raise TypeError("Only integer or boolean dtypes are allowed in bitwise_or")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
Expand All @@ -169,10 +163,7 @@ def bitwise_xor(x1: Array, x2: Array, /) -> Array:
"""
Compute the element-wise bitwise XOR of two arrays.
"""
if (
x1.dtype not in _integer_or_boolean_dtypes
or x2.dtype not in _integer_or_boolean_dtypes
):
if x1.dtype not in _integer_or_boolean_dtypes or x2.dtype not in _integer_or_boolean_dtypes:
raise TypeError("Only integer or boolean dtypes are allowed in bitwise_xor")
# Call result type here just to raise on disallowed type combinations
_result_type(x1.dtype, x2.dtype)
Expand Down Expand Up @@ -410,14 +401,7 @@ def logical_not(x: Array, /) -> Array:
"""
Compute the element-wise logical NOT of a boolean array.
"""
repMsg = ak.generic_msg(
cmd=f"efunc{x._array.ndim}D",
args={
"func": "not",
"array": x._array,
},
)
return Array._new(ak.create_pdarray(repMsg))
return ~x


def logical_or(x1: Array, x2: Array, /) -> Array:
Expand Down
65 changes: 9 additions & 56 deletions arkouda/array_api/linalg.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,15 @@
from .array_object import Array

from arkouda.client import generic_msg
from arkouda.pdarrayclass import create_pdarray, broadcast_if_needed


def matmul(x1: Array, x2: Array, /) -> Array:
"""
Matrix product of two arrays.
"""
from .array_object import Array

if x1._array.ndim < 2 and x2._array.ndim < 2:
raise ValueError(
"matmul requires at least one array argument to have more than two dimensions"
)

x1b, x2b, tmp_x1, tmp_x2 = broadcast_if_needed(x1._array, x2._array)
from arkouda import matmul as ak_matmul

repMsg = generic_msg(
cmd=f"matMul{len(x1b.shape)}D",
args={
"x1": x1b.name,
"x2": x2b.name,
},
)

if tmp_x1:
del x1b
if tmp_x2:
del x2b
from .array_object import Array

return Array._new(create_pdarray(repMsg))
return Array._new(ak_matmul(x1._array, x2._array))


def tensordot():
Expand All @@ -44,42 +23,16 @@ def matrix_transpose(x: Array) -> Array:
"""
Matrix product of two arrays.
"""
from .array_object import Array
from arkouda import transpose as ak_transpose

if x._array.ndim < 2:
raise ValueError(
"matrix_transpose requires the array to have more than two dimensions"
)

repMsg = generic_msg(
cmd=f"transpose{x._array.ndim}D",
args={
"array": x._array.name,
},
)

return Array._new(create_pdarray(repMsg))


def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
from .array_object import Array

x1b, x2b, tmp_x1, tmp_x2 = broadcast_if_needed(x1._array, x2._array)
return Array._new(ak_transpose(x._array))

repMsg = generic_msg(
cmd=f"vecdot{len(x1b.shape)}D",
args={
"x1": x1b.name,
"x2": x2b.name,
"bcShape": x1b.shape,
"axis": axis,
},
)

if tmp_x1:
del x1b
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
from arkouda import vecdot as ak_vecdot

if tmp_x2:
del x2b
from .array_object import Array

return Array._new(create_pdarray(repMsg))
return Array._new(ak_vecdot(x1._array, x2._array))
8 changes: 7 additions & 1 deletion arkouda/pdarraycreation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from arkouda.pdarrayclass import create_pdarray, pdarray
from arkouda.strings import Strings


__all__ = [
"array",
"zeros",
Expand Down Expand Up @@ -284,6 +283,7 @@ def array(
raise RuntimeError(f"Unhandled dtype {a.dtype}")
else:
from arkouda.util import _infer_shape_from_size

shape, ndim, full_size = _infer_shape_from_size(a.shape)

# Do not allow arrays that are too large
Expand Down Expand Up @@ -478,11 +478,15 @@ def zeros(
if dtype_name not in NumericDTypes:
raise TypeError(f"unsupported dtype {dtype}")
from arkouda.util import _infer_shape_from_size

shape, ndim, full_size = _infer_shape_from_size(size)

if ndim > get_max_array_rank():
raise ValueError(f"array rank {ndim} exceeds maximum of {get_max_array_rank()}")

if shape == ():
return scalar_array(0, dtype=dtype)

repMsg = generic_msg(cmd=f"create<{dtype_name},{ndim}>", args={"shape": shape})

return create_pdarray(repMsg, max_bits=max_bits)
Expand Down Expand Up @@ -538,6 +542,7 @@ def ones(
if dtype_name not in NumericDTypes:
raise TypeError(f"unsupported dtype {dtype}")
from arkouda.util import _infer_shape_from_size

shape, ndim, full_size = _infer_shape_from_size(size)

if ndim > get_max_array_rank():
Expand Down Expand Up @@ -607,6 +612,7 @@ def full(
if dtype_name not in NumericDTypes:
raise TypeError(f"unsupported dtype {dtype}")
from arkouda.util import _infer_shape_from_size

shape, ndim, full_size = _infer_shape_from_size(size)

if ndim > get_max_array_rank():
Expand Down
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ testpaths =
tests/array_api/array_creation.py
tests/array_api/array_manipulation.py
tests/array_api/binary_ops.py
tests/array_api/elementwise_functions.py
tests/array_api/indexing.py
tests/array_api/linalg.py
tests/array_api/searching_functions.py
tests/array_api/set_functions.py
tests/array_api/sorting.py
Expand Down
Loading

0 comments on commit ac22c11

Please sign in to comment.