Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jnp.reshape: add copy argument for Array API #22768

Merged
merged 1 commit into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,7 +1289,8 @@ def isrealobj(x: Any) -> bool:

def reshape(
a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *,
newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg()) -> Array:
newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg(),
copy: bool | None = None) -> Array:
"""Return a reshaped copy of an array.

JAX implementation of :func:`numpy.reshape`, implemented in terms of
Expand All @@ -1303,6 +1304,8 @@ def reshape(
order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major
(fortran-style, ``"F"``) or row-major (C-style, ``"C"``) order; default is ``"C"``.
JAX does not support ``order="A"``.
copy: unused by JAX; JAX always returns a copy, though under JIT the compiler
may optimize such copies away.

Returns:
reshaped copy of input array with the specified shape.
Expand Down Expand Up @@ -1355,6 +1358,8 @@ def reshape(
[3, 4],
[5, 6]], dtype=int32)
"""
del copy # unused

__tracebackhide__ = True
util.check_arraylike("reshape", a)

Expand Down
5 changes: 1 addition & 4 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@
real as real,
remainder as remainder,
repeat as repeat,
reshape as reshape,
result_type as result_type,
roll as roll,
round as round,
Expand Down Expand Up @@ -188,10 +189,6 @@
zeros_like as zeros_like,
)

from jax.experimental.array_api._manipulation_functions import (
reshape as reshape,
)

from jax.experimental.array_api._data_type_functions import (
astype as astype,
)
Expand Down
25 changes: 0 additions & 25 deletions jax/experimental/array_api/_manipulation_functions.py

This file was deleted.

Loading