Skip to content

Commit

Permalink
Deprecate newshape argument of jnp.reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed May 9, 2024
1 parent 11da3df commit 708df0f
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ Remember to align the itemized text with the first line of an item within a list
* Deprecations & removals
* The ``kind`` argument to {func}`jax.numpy.sort` and {func}`jax.numpy.argsort`
is now removed. Use `stable=True` or `stable=False` instead.
* The ``newshape`` argument to {func}`jax.numpy.reshape`is being deprecated
and will soon be removed. Use `shape` instead.

## jaxlib 0.4.28

Expand Down
29 changes: 25 additions & 4 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,15 +893,17 @@ def isrealobj(x: Any) -> bool:
return not iscomplexobj(x)


def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array:
def reshape(
a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *,
newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg()) -> Array:
"""Return a reshaped copy of an array.
JAX implementation of :func:`numpy.reshape`, implemented in terms of
:func:`jax.lax.reshape`.
Args:
a: input array to reshape
newshape: integer or sequence of integers giving the new shape, which must match the
shape: integer or sequence of integers giving the new shape, which must match the
size of the input array. If any single dimension is given size ``-1``, it will be
replaced with a value such that the output has the correct size.
order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major
Expand Down Expand Up @@ -961,12 +963,31 @@ def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array:
"""
__tracebackhide__ = True
util.check_arraylike("reshape", a)

# TODO(micky774): deprecated 2024-5-9, remove after deprecation expires.
if not isinstance(newshape, DeprecatedArg):
if shape is not None:
raise ValueError(
"jnp.reshape received both `shape` and `newshape` arguments. Note that "
"using `newshape` is deprecated, please only use `shape` instead."
)
warnings.warn(
"The newshape argument of jax.numpy.reshape is deprecated and setting it "
"will soon raise an error. To avoid an error in the future, and to "
"suppress this warning, please use the shape argument instead.",
DeprecationWarning, stacklevel=2)
shape = newshape
del newshape
elif shape is None:
raise TypeError(
"jnp.shape requires passing a `shape` argument, but none was given."
)
try:
# forward to method for ndarrays
return a.reshape(newshape, order=order) # type: ignore[call-overload,union-attr]
return a.reshape(shape, order=order) # type: ignore[call-overload,union-attr]
except AttributeError:
pass
return asarray(a).reshape(newshape, order=order)
return asarray(a).reshape(shape, order=order)


@partial(jit, static_argnames=('order',), inline=True)
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from jax import Array


# TODO(micky774): Deprecate newshape-->shape in for array API 2023.12
# TODO(micky774): Implement copy
def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array:
"""Reshapes an array without changing its data."""
del copy # unused
Expand Down
3 changes: 2 additions & 1 deletion jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,8 @@ def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = ..., *,
total_repeat_length: Optional[int] = ...) -> Array: ...
def reshape(
a: ArrayLike, newshape: Union[DimSize, Shape], order: str = ...
a: ArrayLike, shape: Union[DimSize, Shape] = ...,
newshape: Union[DimSize, Shape] | None = ..., order: str = ...
) -> Array: ...

def resize(a: ArrayLike, new_shape: Shape) -> Array: ...
Expand Down
8 changes: 4 additions & 4 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,9 +1221,9 @@ def test_reshape(self):
key = random.key(123)
keys = random.split(key, 4)

newshape = (2, 2)
key_func = partial(jnp.reshape, newshape=newshape)
arr_func = partial(jnp.reshape, newshape=(*newshape, *key._impl.key_shape))
shape = (2, 2)
key_func = partial(jnp.reshape, shape=shape)
arr_func = partial(jnp.reshape, shape=(*shape, *key._impl.key_shape))

self.check_shape(key_func, keys)
self.check_against_reference(key_func, arr_func, keys)
Expand Down Expand Up @@ -1291,7 +1291,7 @@ def test_ravel(self):
keys = random.split(key, 4).reshape(2, 2)

key_func = jnp.ravel
arr_func = partial(jnp.reshape, newshape=(4, *key._impl.key_shape))
arr_func = partial(jnp.reshape, shape=(4, *key._impl.key_shape))

self.check_shape(key_func, keys)
self.check_against_reference(key_func, arr_func, keys)
Expand Down

0 comments on commit 708df0f

Please sign in to comment.