Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate newshape argument of jnp.reshape #21130

Merged
merged 1 commit into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Remember to align the itemized text with the first line of an item within a list
* Removed ``get_compute_capability`` from the ``jax.experimental.pallas.gpu``
module. Use the ``compute_capability`` attribute of a GPU device, returned
by {func}`jax.devices` or {func}`jax.local_devices`, instead.
* The ``newshape`` argument to {func}`jax.numpy.reshape`is being deprecated
and will soon be removed. Use `shape` instead.

* Changes
* The minimum jaxlib version of this release is 0.4.27.
Expand Down
29 changes: 25 additions & 4 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,15 +893,17 @@ def isrealobj(x: Any) -> bool:
return not iscomplexobj(x)


def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array:
def reshape(
a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *,
newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg()) -> Array:
Micky774 marked this conversation as resolved.
Show resolved Hide resolved
"""Return a reshaped copy of an array.

JAX implementation of :func:`numpy.reshape`, implemented in terms of
:func:`jax.lax.reshape`.

Args:
a: input array to reshape
newshape: integer or sequence of integers giving the new shape, which must match the
shape: integer or sequence of integers giving the new shape, which must match the
size of the input array. If any single dimension is given size ``-1``, it will be
replaced with a value such that the output has the correct size.
order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major
Expand Down Expand Up @@ -961,12 +963,31 @@ def reshape(a: ArrayLike, newshape: DimSize | Shape, order: str = "C") -> Array:
"""
__tracebackhide__ = True
util.check_arraylike("reshape", a)

# TODO(micky774): deprecated 2024-5-9, remove after deprecation expires.
if not isinstance(newshape, DeprecatedArg):
if shape is not None:
raise ValueError(
"jnp.reshape received both `shape` and `newshape` arguments. Note that "
"using `newshape` is deprecated, please only use `shape` instead."
)
warnings.warn(
"The newshape argument of jax.numpy.reshape is deprecated and setting it "
"will soon raise an error. To avoid an error in the future, and to "
"suppress this warning, please use the shape argument instead.",
DeprecationWarning, stacklevel=2)
shape = newshape
del newshape
elif shape is None:
raise TypeError(
"jnp.shape requires passing a `shape` argument, but none was given."
)
try:
# forward to method for ndarrays
return a.reshape(newshape, order=order) # type: ignore[call-overload,union-attr]
return a.reshape(shape, order=order) # type: ignore[call-overload,union-attr]
except AttributeError:
pass
return asarray(a).reshape(newshape, order=order)
return asarray(a).reshape(shape, order=order)


@partial(jit, static_argnames=('order',), inline=True)
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from jax import Array


# TODO(micky774): Deprecate newshape-->shape in for array API 2023.12
# TODO(micky774): Implement copy
def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array:
"""Reshapes an array without changing its data."""
del copy # unused
Expand Down
3 changes: 2 additions & 1 deletion jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,8 @@ def remainder(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def repeat(a: ArrayLike, repeats: ArrayLike, axis: Optional[int] = ..., *,
total_repeat_length: Optional[int] = ...) -> Array: ...
def reshape(
a: ArrayLike, newshape: Union[DimSize, Shape], order: str = ...
a: ArrayLike, shape: Union[DimSize, Shape] = ...,
newshape: Union[DimSize, Shape] | None = ..., order: str = ...
) -> Array: ...

def resize(a: ArrayLike, new_shape: Shape) -> Array: ...
Expand Down
8 changes: 4 additions & 4 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,9 +1221,9 @@ def test_reshape(self):
key = random.key(123)
keys = random.split(key, 4)

newshape = (2, 2)
key_func = partial(jnp.reshape, newshape=newshape)
arr_func = partial(jnp.reshape, newshape=(*newshape, *key._impl.key_shape))
shape = (2, 2)
key_func = partial(jnp.reshape, shape=shape)
arr_func = partial(jnp.reshape, shape=(*shape, *key._impl.key_shape))

self.check_shape(key_func, keys)
self.check_against_reference(key_func, arr_func, keys)
Expand Down Expand Up @@ -1291,7 +1291,7 @@ def test_ravel(self):
keys = random.split(key, 4).reshape(2, 2)

key_func = jnp.ravel
arr_func = partial(jnp.reshape, newshape=(4, *key._impl.key_shape))
arr_func = partial(jnp.reshape, shape=(4, *key._impl.key_shape))

self.check_shape(key_func, keys)
self.check_against_reference(key_func, arr_func, keys)
Expand Down