[sharding_in_types] Add dot_general sharding rule. We only handle the simple cases and rely on xla to insert the collectives.

Cases where we raise an error:

* batch dimensions whose sharding is inconsistent (None entries are ignored)
* contracting dimensions whose sharding is inconsistent (None entries are ignored)
* lhs.mesh != rhs.mesh
* a batch dimension and a tensored (non-contracting, non-batch) dimension sharded along the same mesh axis
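A minimal usage sketch of the resulting behavior, mirroring the tests added in this change; it assumes the jax_sharding_in_types config flag is enabled and borrows the test suite's jtu.create_mesh helper:

import jax
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec as P
from jax._src import test_util as jtu  # create_mesh, as used by the tests below

# Assumption: the flag can be toggled via jax.config.update; the tests use
# jtu.with_config(jax_sharding_in_types=True) instead.
jax.config.update('jax_sharding_in_types', True)

mesh = jtu.create_mesh((2, 2), ('x', 'y'))
lhs = jax.device_put(np.arange(16.).reshape(8, 2), NamedSharding(mesh, P('x', None)))
rhs = jax.device_put(np.arange(16.).reshape(2, 8), NamedSharding(mesh, P(None, 'y')))

@jax.jit
def f(a, b):
  out = a @ b  # lowers to dot_general; the sharding rule runs at trace time
  # Output spec = batch specs + lhs tensored specs + rhs tensored specs.
  assert out.sharding.spec == P('x', 'y')
  return out

out = f(lhs, rhs)  # any collectives needed for the contracting dims are inserted by xla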

PiperOrigin-RevId: 684983567
yashk2810 authored and Google-ML-Automation committed Oct 11, 2024
1 parent a2973be commit 18bc354
Showing 3 changed files with 117 additions and 9 deletions.
45 changes: 43 additions & 2 deletions jax/_src/lax/lax.py
@@ -3213,6 +3213,46 @@ def _dot_general_shape_computation(lhs_shape, rhs_shape, dimension_numbers):
rhs_tensored_shape = tuple_delete(rhs_shape, rhs_contract_or_batch)
return batch_shape + lhs_tensored_shape + rhs_tensored_shape


def _check_specs_match(lhs_spec, rhs_spec, msg):
for l, r in zip(lhs_spec, rhs_spec):
if l is not None and r is not None and l != r:
raise TypeError(msg)

def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision,
preferred_element_type: DTypeLike | None):
if lhs.sharding.mesh != rhs.sharding.mesh:
raise ValueError(
'Mesh of both lhs and rhs should match. Got lhs:'
f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}')

(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch)
rhs_batch_spec = tuple(rhs.sharding.spec[i] for i in rhs_batch)
msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions "
f"to have the consistent sharding, got {lhs_batch_spec} and "
f"{rhs_batch_spec}.")
_check_specs_match(lhs_batch_spec, rhs_batch_spec, msg)

lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting)
rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting)
msg = ("dot_general requires contracting dimensions to have consistent "
f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.")
_check_specs_match(lhs_contracting_spec, rhs_contracting_spec, msg)

return _dot_general_sharding_computation(
lhs.sharding.spec, rhs.sharding.spec, dimension_numbers, lhs.sharding.mesh)

def _dot_general_sharding_computation(lhs_spec, rhs_spec,
dimension_numbers, mesh):
(lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
batch_spec = tuple(lhs_spec[i] for i in lhs_batch)
lhs_contract_or_batch = tuple(sorted(tuple(lhs_contracting) + tuple(lhs_batch)))
lhs_tensored_spec = tuple_delete(lhs_spec, lhs_contract_or_batch)
rhs_contract_or_batch = tuple(sorted(tuple(rhs_contracting) + tuple(rhs_batch)))
rhs_tensored_spec = tuple_delete(rhs_spec, rhs_contract_or_batch)
return NamedSharding(mesh, P(*(batch_spec + lhs_tensored_spec + rhs_tensored_spec)))

def tuple_delete(tup, idx):
idx_ = set(idx)
return tuple(tup[i] for i in range(len(tup)) if i not in idx_)
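
For intuition, a standalone sketch of the spec computation above, in plain Python with no mesh or NamedSharding; the function name and example specs are illustrative only:

def sketch_dot_general_out_spec(lhs_spec, rhs_spec, dimension_numbers):
  # Mirrors _dot_general_sharding_computation: keep batch specs (taken from lhs),
  # drop contracting/batch dims, then concatenate the remaining ("tensored") specs.
  (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
  batch_spec = tuple(lhs_spec[i] for i in lhs_batch)
  lhs_skip = set(lhs_contracting) | set(lhs_batch)
  rhs_skip = set(rhs_contracting) | set(rhs_batch)
  lhs_tensored = tuple(s for i, s in enumerate(lhs_spec) if i not in lhs_skip)
  rhs_tensored = tuple(s for i, s in enumerate(rhs_spec) if i not in rhs_skip)
  return batch_spec + lhs_tensored + rhs_tensored

# (8, 2) @ (2, 8): contract lhs dim 1 with rhs dim 0, no batch dims.
assert sketch_dot_general_out_spec(('x', None), (None, 'y'),
                                   (((1,), (0,)), ((), ()))) == ('x', 'y')
# FSDP-style case from the tests: both operands sharded on 'x' along dim 0.
assert sketch_dot_general_out_spec(('x', None), ('x', None),
                                   (((1,), (0,)), ((), ()))) == ('x', None)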
@@ -3419,8 +3459,9 @@ def _dot_general_pp_rule(eqn, context, settings) -> pp.Doc:
(list(lhs_cont), list(rhs_cont)), (list(lhs_batch), list(rhs_batch)))
return core._pp_eqn(eqn.replace(params=printed_params), context, settings)

dot_general_p = standard_primitive(_dot_general_shape_rule,
_dot_general_dtype_rule, 'dot_general')
dot_general_p = standard_primitive(
_dot_general_shape_rule, _dot_general_dtype_rule, 'dot_general',
sharding_rule=_dot_general_sharding_rule)


def _dot_general_batch_unpack_args(batch_args):
10 changes: 5 additions & 5 deletions jax/_src/lax/utils.py
@@ -56,11 +56,11 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
out = prim.impl(*[x.val for x in avals], **kwargs)
return core.ConcreteArray(out.dtype, out, weak_type=weak_type)
elif least_specialized is core.ShapedArray:
out_sharding = (sharding_rule(*avals, **kwargs)
if config.sharding_in_types.value else None)
return core.ShapedArray(shape_rule(*avals, **kwargs),
dtype_rule(*avals, **kwargs), weak_type=weak_type,
sharding=out_sharding)
return core.ShapedArray(
shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
weak_type=weak_type,
sharding=(sharding_rule(*avals, **kwargs)
if config.sharding_in_types.value else None))
elif least_specialized is core.DShapedArray:
shape = shape_rule(*avals, **kwargs)
ty = (core.ShapedArray if all(type(d) is int for d in shape)
71 changes: 69 additions & 2 deletions tests/pjit_test.py
@@ -4603,9 +4603,9 @@ def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")


@jtu.with_config(jax_sharding_in_types=True)
class ShardingInTypesTest(jtu.JaxTestCase):

@config.sharding_in_types(True)
def test_basic_mul(self):
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
@@ -4631,7 +4631,6 @@ def f(x):
else:
self.assertEqual(lowered_text.count('@Sharding'), 2)

@config.sharding_in_types(True)
def test_fully_replicated_array_mul(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp1 = np.arange(16).reshape(8, 2)
@@ -4672,6 +4671,74 @@ def g(x, y):
TypeError, "mul got incompatible shardings for broadcasting"):
g(arr1, jax.device_put(np_inp1, NamedSharding(mesh, P(('x', 'y')))))

@parameterized.named_parameters(
('x_y', P('x', None), P(None, 'y'), P('x', 'y'), None),
('x_None', P('x', None), P(None, None), P('x', None), None),
('contracting1', P('x', 'y'), P('y', None), P('x', None), 'all-reduce'),
('contracting2', P('x', 'y'), P(None, None), P('x', None), 'all-gather'),
('fsdp', P('x', None), P('x', None), P('x', None), 'all-gather'),
('half_tp', P(None, 'y'), P(None, 'y'), P(None, 'y'), 'all-gather'),
('other_half_tp', P(None, 'y'), P('y', None), P(None, None), 'all-reduce')
)
def test_dot_general_basic(self, spec1, spec2, out_spec, collective_name):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp1 = np.arange(16).reshape(8, 2)
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1))
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2))

@jax.jit
def f(x, y):
out = x @ y
self.assertEqual(out.sharding.spec, out_spec)
return out

out = f(arr1, arr2)
self.assertArraysEqual(out, np_inp1 @ np_inp1.T)
self.assertEqual(out.aval.sharding.spec, out_spec)

compiled_text = f.lower(arr1, arr2).compile().as_text()
if collective_name is not None and compiled_text is not None:
self.assertIn(collective_name, compiled_text)

@parameterized.named_parameters(
('fail1', P('x', 'y'), P('y', 'x'),
"PartitionSpec.*x.*x.*has duplicate entries", ValueError),
('fail2', P('x', 'y'), P('x', 'y'),
"dot_general requires contracting dimensions to have consistent sharding",
TypeError),
)
def test_dot_general_error(self, spec1, spec2, error_msg, error_type):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp1 = np.arange(16).reshape(8, 2)
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, spec1))
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, spec2))

@jax.jit
def f(x, y):
return x @ y

with self.assertRaisesRegex(error_type, error_msg):
f(arr1, arr2)

def test_dot_general_batch_error(self):
mesh = jtu.create_mesh((2, 2, 1), ('x', 'y', 'z'))
arr1 = jax.device_put(np.ones((8, 4, 2)),
NamedSharding(mesh, P('x', 'y', 'z')))
arr2 = jax.device_put(np.ones((8, 2, 4)),
NamedSharding(mesh, P('y', 'z', 'x')))
with self.assertRaisesRegex(
TypeError,
'dot_general requires lhs batch dimensions and rhs batch dimensions to'
' have the consistent sharding'):
jax.lax.dot_general(
arr1, arr2, dimension_numbers=(([2], [1]), ([0], [0])))

with self.assertRaisesRegex(
TypeError,
'dot_general requires lhs batch dimensions and rhs batch dimensions to'
' have the consistent sharding'):
jnp.einsum('abc,acz->abz', arr1, arr2)


@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):
