From af02c6a08e6d22377747e9f97ac94b2356fa736e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 10 Nov 2023 14:51:04 -0800 Subject: [PATCH] fix manipulation functions --- .../array_api/_manipulation_functions.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/jax/experimental/array_api/_manipulation_functions.py b/jax/experimental/array_api/_manipulation_functions.py index 8d1bdf8d3b24..411476f229d7 100644 --- a/jax/experimental/array_api/_manipulation_functions.py +++ b/jax/experimental/array_api/_manipulation_functions.py @@ -41,20 +41,14 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[i def expand_dims(x: Array, /, *, axis: int = 0) -> Array: """Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by axis.""" - return jax.lax.expand_dims(x, dimensions=[axis]) + if axis < -x.ndim - 1 or axis > x.ndim: + raise IndexError(f"{axis=} is out of bounds for array of dimension {x.ndim}") + return jax.numpy.expand_dims(x, axis=axis) def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: """Reverses the order of elements in an array along the given axis.""" - if axis is None: - dimensions = tuple(range(x.ndim)) - elif isinstance(axis, int): - dimensions = (axis,) - elif isinstance(axis, tuple): - dimensions = tuple(operator.index(ax) for ax in axis) - else: - raise TypeError(f"Unexpected input axis={axis}: expected None, int, or tuple of ints") - return jax.lax.rev(x, dimensions=dimensions) + return jax.numpy.flip(x, axis=axis) def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: @@ -65,7 +59,7 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: def reshape(x: Array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None) -> Array: """Reshapes an array without changing its data.""" del copy # unused - return jax.lax.reshape(x, shape) + return jax.numpy.reshape(x, shape) def roll(x: Array, /, shift: Union[int, Tuple[int]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array: