From 5198db9fdbbbc0f8bb52d99e28d3877af60ecf68 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 30 Jul 2024 14:07:08 -0700 Subject: [PATCH] jnp.repeat: add copy argument for Array API --- jax/_src/numpy/lax_numpy.py | 7 +++++- jax/experimental/array_api/__init__.py | 5 +--- .../array_api/_manipulation_functions.py | 25 ------------------- 3 files changed, 7 insertions(+), 30 deletions(-) delete mode 100644 jax/experimental/array_api/_manipulation_functions.py diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index c4a9ee17fe40..7ad74c84b424 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1289,7 +1289,8 @@ def isrealobj(x: Any) -> bool: def reshape( a: ArrayLike, shape: DimSize | Shape | None = None, order: str = "C", *, - newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg()) -> Array: + newshape: DimSize | Shape | DeprecatedArg = DeprecatedArg(), + copy: bool | None = None) -> Array: """Return a reshaped copy of an array. JAX implementation of :func:`numpy.reshape`, implemented in terms of @@ -1303,6 +1304,8 @@ def reshape( order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major (fortran-style, ``"F"``) or row-major (C-style, ``"C"``) order; default is ``"C"``. JAX does not support ``order="A"``. + copy: unused by JAX; JAX always returns a copy, though under JIT the compiler + may optimize such copies away. Returns: reshaped copy of input array with the specified shape. @@ -1355,6 +1358,8 @@ def reshape( [3, 4], [5, 6]], dtype=int32) """ + del copy # unused + __tracebackhide__ = True util.check_arraylike("reshape", a) diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index 9a0be504f81a..69d1a083058f 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -148,6 +148,7 @@ real as real, remainder as remainder, repeat as repeat, + reshape as reshape, result_type as result_type, roll as roll, round as round, @@ -188,10 +189,6 @@ zeros_like as zeros_like, ) -from jax.experimental.array_api._manipulation_functions import ( - reshape as reshape, -) - from jax.experimental.array_api._data_type_functions import ( astype as astype, ) diff --git a/jax/experimental/array_api/_manipulation_functions.py b/jax/experimental/array_api/_manipulation_functions.py deleted file mode 100644 index c364b9f5b79c..000000000000 --- a/jax/experimental/array_api/_manipulation_functions.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import jax -from jax import Array - - -# TODO(micky774): Implement copy -def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array: - """Reshapes an array without changing its data.""" - del copy # unused - return jax.numpy.reshape(x, shape)