[Rollback] We still want to allow multiple meshes in the user program
Reverts dd958ad

PiperOrigin-RevId: 661537233
yashk2810 authored and jax authors committed Aug 10, 2024
1 parent abc9ba0 commit c08656c
Showing 6 changed files with 38 additions and 66 deletions.
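
In user-facing terms, the behavior this rollback preserves is that one program may build several meshes (even with different axis names) and feed the same jitted function arrays laid out on each of them, while still reusing the cached lowering; that is what the ShardingContext and pjit_test changes below protect. A minimal sketch of the pattern, loosely modeled on the modified test_lowering_cache_hit_different_devices; it assumes at least four available devices, and the array shape and axis names are illustrative:

import jax
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec as P

# Two user-level meshes over disjoint devices, with different axis names.
mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x')
mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'y')

@jax.jit
def f(x):
  return x * 2

a = jax.device_put(np.arange(8.0), NamedSharding(mesh1, P('x')))
b = jax.device_put(np.arange(8.0), NamedSharding(mesh2, P('y')))

f(a)  # traced, lowered, and compiled
f(b)  # same shape, dtype, and device count: the cached lowering can be reused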
6 changes: 1 addition & 5 deletions jax/_src/custom_partitioning.py
@@ -24,7 +24,6 @@
 from typing import Any
 import weakref
 
-import numpy as np
 import jax
 from jax import tree_util
 from jax._src import api_util
@@ -482,20 +481,17 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
                                        infer_sharding_from_operands,
                                        decode_shardings,
                                        static_args):
+  mesh = mesh_lib.thread_resources.env.physical_mesh
   axis_context = ctx.module_context.axis_context
   if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
       set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)):
     return mlir.lower_fun(core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values)
 
-  mesh = mesh_lib.thread_resources.env.physical_mesh
   if isinstance(axis_context, sharding_impls.ShardingContext):
     devices = axis_context.device_assignment
     if devices is None:
       raise AssertionError(
           'Please file a bug at https://github.com/google/jax/issues')
-    if axis_context.mesh_shape is not None:
-      ma, ms = list(zip(*axis_context.mesh_shape))
-      mesh = mesh_lib.Mesh(np.array(devices).reshape(ms), ma)
   elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
     devices = axis_context.mesh._flat_devices_tuple
   else:
6 changes: 3 additions & 3 deletions jax/_src/interpreters/mlir.py
@@ -957,6 +957,7 @@ def lower_jaxpr_to_module(
     input_output_aliases: None | tuple[int | None, ...] = None,
     propagated_out_mem_kinds: tuple[None | str, ...] | None = None,
     lowering_parameters: LoweringParameters,
+    mesh_shape_tuple: tuple[tuple[str, int], ...] | None = None,
 ) -> LoweringResult:
   """Lowers a top-level jaxpr to an MLIR module.
@@ -1044,14 +1045,13 @@ def lower_jaxpr_to_module(
     # XLA computation preserves the module name.
     attrs = ctx.module.operation.attributes
     if config.use_shardy_partitioner.value:
-      assert (isinstance(axis_context, sharding_impls.ShardingContext) and
-              axis_context.mesh_shape is not None)
+      assert mesh_shape_tuple is not None
       ctx.module.body.append(
           dialects.sdy.MeshOp(
               "mesh",
               dialects.sdy.MeshAttr.get(
                   [dialects.sdy.MeshAxisAttr.get(name, size)
-                   for name, size in axis_context.mesh_shape])))
+                   for name, size in mesh_shape_tuple])))
     module_name = _module_name_regex.sub("_", module_name)
     attrs["sym_name"] = ir.StringAttr.get(module_name)
     attrs["mhlo.num_replicas"] = i32_attr(num_replicas)
23 changes: 11 additions & 12 deletions jax/_src/interpreters/pxla.py
@@ -1881,7 +1881,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
                             propagated_out_mem_kinds: tuple[None | str, ...],
                             platforms: tuple[str, ...],
                             lowering_parameters: mlir.LoweringParameters,
-                            mesh_shape_tuple: tuple[tuple[str, int], ...] | None):
+                            mesh_shape_tuple: tuple[tuple[str, int], ...]):
   jaxpr = closed_jaxpr.jaxpr
   in_shardings = semantic_in_shardings.shardings
   out_shardings = semantic_out_shardings.shardings
@@ -1911,8 +1911,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
     in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings)
     out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
     replicated_args = [False] * len(global_in_avals)
-    axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment,
-                                              mesh_shape=mesh_shape_tuple)
+    axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment)
     num_partitions = num_devices
   else:
     # This path is triggered for `jit(pmap)` cases.
@@ -1958,7 +1957,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       all_default_mem_kind=all_default_mem_kind,
       input_output_aliases=inout_aliases,
       propagated_out_mem_kinds=propagated_out_mem_kinds,
-      lowering_parameters=lowering_parameters)
+      lowering_parameters=lowering_parameters,
+      mesh_shape_tuple=mesh_shape_tuple)
   tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
   unordered_effects = list(
       effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
@@ -2203,15 +2203,14 @@ def lower_sharding_computation(
   semantic_out_shardings = SemanticallyEqualShardings(
       out_shardings, global_out_avals)  # type: ignore
   prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
-
-  # TODO(yashkatariya): Initialize with context_mesh here?
   mesh_shape_tuple = None
-  for sharding in it.chain(
-      in_shardings, out_shardings,
-      [js for js, _ in unique_intermediate_shardings]):
-    if isinstance(sharding, sharding_impls.NamedSharding):
-      mesh_shape_tuple = sharding.mesh.shape_tuple
-      break
+  if config.use_shardy_partitioner.value:
+    for sharding in it.chain(
+        in_shardings, out_shardings,
+        [js for js, _ in unique_intermediate_shardings]):
+      if isinstance(sharding, sharding_impls.NamedSharding):
+        mesh_shape_tuple = sharding.mesh.shape_tuple
+        break
 
   (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
1 change: 0 additions & 1 deletion jax/_src/sharding_impls.py
@@ -1162,7 +1162,6 @@ class ShardingContext:
   """
   num_devices: int
   device_assignment: tuple[xc.Device, ...] | None = None
-  mesh_shape: tuple[tuple[str, int], ...] | None = None
 
   def __post_init__(self):
     if self.device_assignment is not None:
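
For orientation, a rough sketch of the dataclass this hunk leaves behind (docstring and the rest of __post_init__ elided; not a verbatim copy of jax/_src/sharding_impls.py):

import dataclasses
from jax._src.lib import xla_client as xc

@dataclasses.dataclass(frozen=True)
class ShardingContext:
  # Only the device assignment is kept; with mesh_shape gone, the lowering
  # context no longer depends on a user mesh or its axis names.
  num_devices: int
  device_assignment: tuple[xc.Device, ...] | None = None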
17 changes: 17 additions & 0 deletions tests/memories_test.py
@@ -1178,6 +1178,23 @@ def test_jit_cpp_cache_hit(self):
     self.assertArraysEqual(out, np_inp @ np_inp.T)
     self.assertArraysEqual(out2, np_inp @ np_inp.T)
 
+  def test_jit_compilation_cache_hit(self):
+    mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y"))
+    inp2 = jax.device_put(
+        np_inp, GSPMDSharding(tuple(mesh.devices.flat),
+                              s._to_xla_hlo_sharding(inp.ndim),
+                              memory_kind="device")
+    )
+
+    f = jax.jit(lambda x: x @ x.T)
+
+    with (jtu.count_pjit_cpp_cache_miss() as cpp_count,
+          jtu.count_jit_and_pmap_lowerings() as compile_count):
+      f(inp)
+      f(inp2)
+    self.assertEqual(cpp_count[0], 2)
+    self.assertEqual(compile_count[0], 1)
+
   def test_jit_cpp_cache_output_hit(self):
     _, _, _, inp = _create_inputs((8, 2), P("x"), mem_kind="device")
 
51 changes: 6 additions & 45 deletions tests/pjit_test.py
@@ -1483,45 +1483,6 @@ def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
     pjit_f = pjit(jit_f, in_shardings=(P('x')), out_shardings=P('x'))
     self.assertArraysEqual(x, pjit_f(x))
 
-  def test_custom_partitioning_no_mesh_context(self):
-    self.skip_if_custom_partitioning_not_supported()
-
-    @custom_partitioning
-    def f(x):
-      return x
-
-    def partition(mesh, arg_shapes, result_shape):
-      def lower_fn(x):
-        @jax.jit
-        def g(y):
-          return y
-
-        return g(x)
-
-      x_shard = arg_shapes[0].sharding
-      return (
-          mesh,
-          lower_fn,
-          NamedSharding(x_shard.mesh, P('x')),
-          (NamedSharding(x_shard.mesh, P('x')),),
-      )
-
-    def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
-      x_shard = arg_shapes[0].sharding
-      return NamedSharding(x_shard.mesh, P('x'))
-
-    f.def_partition(
-        infer_sharding_from_operands=infer_sharding_from_operands,
-        partition=partition,
-    )
-
-    mesh = jtu.create_global_mesh((4,), ('x',))
-    x = np.asarray(np.random.randint(0, 20, (32,)), dtype=np.float32)
-    s = NamedSharding(mesh, P('x'))
-
-    pjit_f = jax.jit(f, in_shardings=s, out_shardings=s)
-    self.assertArraysEqual(x, pjit_f(x))
-
   @jtu.with_mesh([('x', 4)])
   def test_custom_partitioner_with_scan(self):
     self.skip_if_custom_partitioning_not_supported()
@@ -3448,8 +3409,8 @@ def mul(x):
     cache_info4 = pxla._cached_compilation.cache_info()
     self.assertIsInstance(out4.sharding, PositionalSharding)
 
-    self.assertEqual(cache_info4.hits, cache_info3.hits)
-    self.assertEqual(cache_info4.misses, cache_info3.misses + 1)
+    self.assertEqual(cache_info4.hits, cache_info3.hits + 1)
+    self.assertEqual(cache_info4.misses, cache_info3.misses)
 
   def test_cache_hit_pjit_lower_with_cpp_cache_miss(self):
     mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
@@ -3560,8 +3521,8 @@ def test_jit_mul_sum_sharding_preserved(self):
     self.assertIsInstance(out3.sharding, PositionalSharding)
     self.assertEqual(count[0], 1)
 
-    self.assertEqual(cache_info2.hits, cache_info1.hits)
-    self.assertEqual(cache_info2.misses, cache_info1.misses + 1)
+    self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
+    self.assertEqual(cache_info2.misses, cache_info1.misses)
 
     self.assertEqual(pl_cache_info2.hits, pl_cache_info1.hits)
     self.assertEqual(pl_cache_info2.misses, pl_cache_info1.misses + 1)
@@ -3853,7 +3814,7 @@ def test_lowering_cache_hit_different_devices(self):
       self.skipTest('Requires >=4 devices')
 
     mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x')
-    mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'x')
+    mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'y')
 
     @jax.jit
     def f(x):
@@ -3864,7 +3825,7 @@ def g(a):
       out_a = f(a) # lowering cached
 
       # same num_devices but different devices.
-      b = jax.device_put(out_a, NamedSharding(mesh2, P('x')))
+      b = jax.device_put(out_a, NamedSharding(mesh2, P('y')))
       f(b) # lowering cache *hit*
 
     with jtu.count_jit_and_pmap_lowerings() as count: