diff --git a/jax/_src/core.py b/jax/_src/core.py
index cc290e20a889..5965ec57e95b 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -1817,7 +1817,7 @@ def str_short(self, short_dtypes=False):
     dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes
               else self.dtype.name)
     dt_str = dt_str.replace('void', 'float0')
-    if hasattr(self, 'sharding'):
+    if hasattr(self, 'sharding') and self.sharding is not None:
       shapestr = ','.join(_get_shape_sharding_str(self.shape, self.sharding.spec))
       return f'{dt_str}[{shapestr}]'
     else:
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index fcf766357c25..00db56bf563a 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -3928,9 +3928,15 @@ def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
     msg = ("broadcast_in_dim broadcast_dimensions must be strictly increasing; "
            "got broadcast_dimensions {}")
     raise TypeError(msg.format(broadcast_dimensions))
-
   return shape
 
+def _broadcast_in_dim_sharding_rule(operand, *, shape, broadcast_dimensions):
+  bds = set(broadcast_dimensions)
+  orig_spec = iter(operand.sharding.spec)
+  new_spec = [next(orig_spec) if i in bds else None for i in range(len(shape))]
+  assert next(orig_spec, None) is None
+  return NamedSharding(operand.sharding.mesh, P(*new_spec))
+
 def _broadcast_in_dim_typecheck_rule(
     _, operand, *dyn_shape, shape, broadcast_dimensions):
   if not dyn_shape:
@@ -4079,10 +4085,12 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions) ->
   aval_out, = ctx.avals_out
   if dyn_shape:
     aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape))
-
-
-  return [mlir.broadcast_in_dim(ctx, x, aval_out,
-                                broadcast_dimensions=broadcast_dimensions)]
+  out = mlir.broadcast_in_dim(ctx, x, aval_out,
+                              broadcast_dimensions=broadcast_dimensions)
+  if config.sharding_in_types.value:
+    proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
+    return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
+  return [out]
 
 def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
   if (not dyn_shape and
@@ -4090,7 +4098,12 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions):
               type(core.get_aval(d).dtype) is core.bint for d in shape)):
     shape = _broadcast_in_dim_shape_rule(  # error checking
         x, shape=shape, broadcast_dimensions=broadcast_dimensions)
-    return core.ShapedArray(shape, x.dtype, x.weak_type)
+    if config.sharding_in_types.value:
+      sharding = _broadcast_in_dim_sharding_rule(
+          x, shape=shape, broadcast_dimensions=broadcast_dimensions)
+    else:
+      sharding = None
+    return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=sharding)
   # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray
   # (even if x is a ShapedArray)
   # TODO(mattjj): unify DShapedArray with ShapedArray, and remove this code
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 08229c331454..5fe89750e216 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4784,6 +4784,33 @@ def f(x):
     if reduce and compiled_text is not None:
       self.assertIn('all-reduce', compiled_text)
 
+  @parameterized.named_parameters(
+      ('0', 0, P(None, 'x', 'y')),
+      ('1', 1, P('x', None, 'y')),
+      ('2', 2, P('x', 'y', None)),
+      ('-1', -1, P('x', 'y', None)),
+  )
+  def test_broadcast_in_dim(self, axis, out_spec):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr = jax.device_put(np_inp, s)
+
+    out = jnp.expand_dims(arr, axis=axis)
+    self.assertEqual(out.aval.sharding.spec, out_spec)
+
+    @jax.jit
+    def f(x):
+      y = jnp.expand_dims(x, axis=axis)
+      self.assertEqual(y.sharding.spec, out_spec)
+      return y
+
+    out = f(arr)
+    self.assertEqual(out.aval.sharding.spec, out_spec)
+
+    lowered_text = f.lower(arr).as_text()
+    self.assertIn('@Sharding', lowered_text)
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):