Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[array api] add jax.numpy.concat #19323

Merged
merged 1 commit into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ namespace; they are listed below.
complexfloating
ComplexWarning
compress
concat
concatenate
conj
conjugate
Expand Down
10 changes: 8 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1869,10 +1869,10 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike],
util.check_arraylike("concatenate", *arrays)
if not len(arrays):
raise ValueError("Need at least one array to concatenate.")
if ndim(arrays[0]) == 0:
raise ValueError("Zero-dimensional arrays cannot be concatenated.")
if axis is None:
return concatenate([ravel(a) for a in arrays], axis=0, dtype=dtype)
if ndim(arrays[0]) == 0:
raise ValueError("Zero-dimensional arrays cannot be concatenated.")
axis = _canonicalize_axis(axis, ndim(arrays[0]))
if dtype is None:
arrays_out = util.promote_dtypes(*arrays)
Expand All @@ -1888,6 +1888,12 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike],
return arrays_out[0]


@util._wraps(getattr(np, "concat", None))
def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array:
util.check_arraylike("concat", *arrays)
return jax.numpy.concatenate(arrays, axis=axis)


@util._wraps(np.vstack)
def vstack(tup: np.ndarray | Array | Sequence[ArrayLike],
dtype: DTypeLike | None = None) -> Array:
Expand Down
5 changes: 1 addition & 4 deletions jax/experimental/array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@ def broadcast_to(x: Array, /, shape: tuple[int]) -> Array:
def concat(arrays: tuple[Array, ...] | list[Array], /, *, axis: int | None = 0) -> Array:
"""Joins a sequence of arrays along an existing axis."""
dtype = _result_type(*arrays)
if axis is None:
arrays = [reshape(arr, (arr.size,)) for arr in arrays]
axis = 0
return jax.numpy.concatenate(arrays, axis=axis, dtype=dtype)
return jax.numpy.concat([arr.astype(dtype) for arr in arrays], axis=axis)


def expand_dims(x: Array, /, *, axis: int = 0) -> Array:
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
complex_ as complex_,
complexfloating as complexfloating,
compress as compress,
concat as concat,
concatenate as concatenate,
convolve as convolve,
copy as copy,
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ complex_: Any
complexfloating = _np.complexfloating
def compress(condition: ArrayLike, a: ArrayLike, axis: Optional[int] = ...,
out: None = ...) -> Array: ...
def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: ...
def concatenate(
arrays: Union[_np.ndarray, Array, Sequence[ArrayLike]],
axis: Optional[int] = ...,
Expand Down
32 changes: 30 additions & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,7 +1471,7 @@ def testCompressMethod(self, shape, dtype, axis):
@jtu.sample_product(
[dict(base_shape=base_shape, axis=axis)
for base_shape in [(4,), (3, 4), (2, 3, 4)]
for axis in range(-len(base_shape)+1, len(base_shape))
for axis in (None, *range(-len(base_shape)+1, len(base_shape)))
],
arg_dtypes=[
arg_dtypes
Expand All @@ -1482,7 +1482,7 @@ def testCompressMethod(self, shape, dtype, axis):
)
def testConcatenate(self, axis, dtype, base_shape, arg_dtypes):
rng = jtu.rand_default(self.rng())
wrapped_axis = axis % len(base_shape)
wrapped_axis = 0 if axis is None else axis % len(base_shape)
shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)]
@jtu.promote_like_jnp
Expand Down Expand Up @@ -1521,6 +1521,34 @@ def testConcatenateAxisNone(self):
b = jnp.array([[5]])
jnp.concatenate((a, b), axis=None)

def testConcatenateScalarAxisNone(self):
arrays = [np.int32(0), np.int32(1)]
self.assertArraysEqual(jnp.concatenate(arrays, axis=None),
np.concatenate(arrays, axis=None))

@jtu.sample_product(
[dict(base_shape=base_shape, axis=axis)
for base_shape in [(), (4,), (3, 4), (2, 3, 4)]
for axis in (None, *range(-len(base_shape)+1, len(base_shape)))
],
dtype=default_dtypes,
)
def testConcat(self, axis, base_shape, dtype):
rng = jtu.rand_default(self.rng())
wrapped_axis = 0 if axis is None else axis % len(base_shape)
shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
for size in [3, 1, 4]]
@jtu.promote_like_jnp
def np_fun(*args):
if jtu.numpy_version() >= (2, 0, 0):
return np.concat(args, axis=axis)
else:
return np.concatenate(args, axis=axis)
jnp_fun = lambda *args: jnp.concat(args, axis=axis)
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
[dict(base_shape=base_shape, axis=axis)
for base_shape in [(4,), (3, 4), (2, 3, 4)]
Expand Down