diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index adb8131a49d7..e47858a7c584 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -121,6 +121,7 @@ namespace; they are listed below. complexfloating ComplexWarning compress + concat concatenate conj conjugate diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index c9dcb674fadd..29479cb60a68 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1869,10 +1869,10 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], util.check_arraylike("concatenate", *arrays) if not len(arrays): raise ValueError("Need at least one array to concatenate.") - if ndim(arrays[0]) == 0: - raise ValueError("Zero-dimensional arrays cannot be concatenated.") if axis is None: return concatenate([ravel(a) for a in arrays], axis=0, dtype=dtype) + if ndim(arrays[0]) == 0: + raise ValueError("Zero-dimensional arrays cannot be concatenated.") axis = _canonicalize_axis(axis, ndim(arrays[0])) if dtype is None: arrays_out = util.promote_dtypes(*arrays) @@ -1888,6 +1888,12 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], return arrays_out[0] +@util._wraps(getattr(np, "concat", None)) +def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: + util.check_arraylike("concat", *arrays) + return jax.numpy.concatenate(arrays, axis=axis) + + @util._wraps(np.vstack) def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: diff --git a/jax/experimental/array_api/_manipulation_functions.py b/jax/experimental/array_api/_manipulation_functions.py index ff52949ee2fd..4b016db05880 100644 --- a/jax/experimental/array_api/_manipulation_functions.py +++ b/jax/experimental/array_api/_manipulation_functions.py @@ -32,10 +32,7 @@ def broadcast_to(x: Array, /, shape: tuple[int]) -> Array: def concat(arrays: tuple[Array, ...] | list[Array], /, *, axis: int | None = 0) -> Array: """Joins a sequence of arrays along an existing axis.""" dtype = _result_type(*arrays) - if axis is None: - arrays = [reshape(arr, (arr.size,)) for arr in arrays] - axis = 0 - return jax.numpy.concatenate(arrays, axis=axis, dtype=dtype) + return jax.numpy.concat([arr.astype(dtype) for arr in arrays], axis=axis) def expand_dims(x: Array, /, *, axis: int = 0) -> Array: diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index b305b8a5e2c0..2bdb4724b527 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -64,6 +64,7 @@ complex_ as complex_, complexfloating as complexfloating, compress as compress, + concat as concat, concatenate as concatenate, convolve as convolve, copy as copy, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 8b8d0be502a9..92848a8efb69 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -186,6 +186,7 @@ complex_: Any complexfloating = _np.complexfloating def compress(condition: ArrayLike, a: ArrayLike, axis: Optional[int] = ..., out: None = ...) -> Array: ... +def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: ... def concatenate( arrays: Union[_np.ndarray, Array, Sequence[ArrayLike]], axis: Optional[int] = ..., diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 6f725d93483e..7f8ad632e114 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1471,7 +1471,7 @@ def testCompressMethod(self, shape, dtype, axis): @jtu.sample_product( [dict(base_shape=base_shape, axis=axis) for base_shape in [(4,), (3, 4), (2, 3, 4)] - for axis in range(-len(base_shape)+1, len(base_shape)) + for axis in (None, *range(-len(base_shape)+1, len(base_shape))) ], arg_dtypes=[ arg_dtypes @@ -1482,7 +1482,7 @@ def testCompressMethod(self, shape, dtype, axis): ) def testConcatenate(self, axis, dtype, base_shape, arg_dtypes): rng = jtu.rand_default(self.rng()) - wrapped_axis = axis % len(base_shape) + wrapped_axis = 0 if axis is None else axis % len(base_shape) shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)] @jtu.promote_like_jnp @@ -1521,6 +1521,34 @@ def testConcatenateAxisNone(self): b = jnp.array([[5]]) jnp.concatenate((a, b), axis=None) + def testConcatenateScalarAxisNone(self): + arrays = [np.int32(0), np.int32(1)] + self.assertArraysEqual(jnp.concatenate(arrays, axis=None), + np.concatenate(arrays, axis=None)) + + @jtu.sample_product( + [dict(base_shape=base_shape, axis=axis) + for base_shape in [(), (4,), (3, 4), (2, 3, 4)] + for axis in (None, *range(-len(base_shape)+1, len(base_shape))) + ], + dtype=default_dtypes, + ) + def testConcat(self, axis, base_shape, dtype): + rng = jtu.rand_default(self.rng()) + wrapped_axis = 0 if axis is None else axis % len(base_shape) + shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:] + for size in [3, 1, 4]] + @jtu.promote_like_jnp + def np_fun(*args): + if jtu.numpy_version() >= (2, 0, 0): + return np.concat(args, axis=axis) + else: + return np.concatenate(args, axis=axis) + jnp_fun = lambda *args: jnp.concat(args, axis=axis) + args_maker = lambda: [rng(shape, dtype) for shape in shapes] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( [dict(base_shape=base_shape, axis=axis) for base_shape in [(4,), (3, 4), (2, 3, 4)]