Skip to content

Commit

Permalink
fix manipulation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Nov 10, 2023
1 parent dfd450a commit af02c6a
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions jax/experimental/array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,14 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[i

def expand_dims(x: Array, /, *, axis: int = 0) -> Array:
"""Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by axis."""
return jax.lax.expand_dims(x, dimensions=[axis])
if axis < -x.ndim - 1 or axis > x.ndim:
raise IndexError(f"{axis=} is out of bounds for array of dimension {x.ndim}")
return jax.numpy.expand_dims(x, axis=axis)


def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
"""Reverses the order of elements in an array along the given axis."""
if axis is None:
dimensions = tuple(range(x.ndim))
elif isinstance(axis, int):
dimensions = (axis,)
elif isinstance(axis, tuple):
dimensions = tuple(operator.index(ax) for ax in axis)
else:
raise TypeError(f"Unexpected input axis={axis}: expected None, int, or tuple of ints")
return jax.lax.rev(x, dimensions=dimensions)
return jax.numpy.flip(x, axis=axis)


def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
Expand All @@ -65,7 +59,7 @@ def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array:
def reshape(x: Array, /, shape: Tuple[int, ...], *, copy: Optional[bool] = None) -> Array:
"""Reshapes an array without changing its data."""
del copy # unused
return jax.lax.reshape(x, shape)
return jax.numpy.reshape(x, shape)


def roll(x: Array, /, shift: Union[int, Tuple[int]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Array:
Expand Down

0 comments on commit af02c6a

Please sign in to comment.