Remove physical_hlo_sharding from TyRules.
The only caller of `physical_hlo_sharding` outside of TyRules was mlir.py. This CL also changes `lower_jaxpr_to_fun` to accept only logical arg_shardings and result_shardings, which are XLACompatibleShardings.

PiperOrigin-RevId: 615977878
yashk2810 authored and jax authors committed Mar 15, 2024
1 parent 808455e commit b887346
Showing 6 changed files with 35 additions and 64 deletions.
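For orientation, here is a hedged sketch (not part of the commit; positional argument list abbreviated, names as in pjit.py) of the new contract from a caller's side: lowering code now hands `lower_jaxpr_to_fun` the logical XLACompatibleShardings directly instead of pre-converting them with `_to_xla_hlo_sharding`, as the pjit.py hunk below shows.

# Sketch only -- argument list abbreviated; mirrors the pjit.py change below.
arg_shardings = [None if is_unspecified(s) else s for s in in_shardings]
result_shardings = [None if is_unspecified(s) else s for s in out_shardings]
func = mlir.lower_jaxpr_to_fun(
    mod_ctx, name, jaxpr, effects,
    arg_shardings=arg_shardings,        # logical XLACompatibleShardings (or None)
    result_shardings=result_shardings,  # converted to OpSharding protos later,
)                                       # inside mlir._to_physical_op_sharding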
35 changes: 11 additions & 24 deletions jax/_src/interpreters/mlir.py
@@ -819,16 +819,19 @@ class LoweringResult(NamedTuple):
 _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
 
 
-def _to_logical_op_sharding(
+def _to_physical_op_sharding(
     aval: core.AbstractValue, sharding: XLACompatibleSharding | None,
-) -> xc.HloSharding | None:
+) -> xc.OpSharding | None:
   if sharding is None:
     return None
   assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
   if isinstance(aval, AbstractRef):
-    return _to_logical_op_sharding(aval.inner_aval, sharding)
+    return _to_physical_op_sharding(aval.inner_aval, sharding)
   assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
-  return sharding._to_xla_hlo_sharding(aval.ndim)
+  if dtypes.issubdtype(aval.dtype, dtypes.extended):
+    sharding = aval.dtype._rules.physical_sharding(aval, sharding)
+    aval = core.physical_aval(aval)
+  return sharding._to_xla_hlo_sharding(aval.ndim).to_proto()  # type: ignore
 
 
 def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str | None:
@@ -941,13 +944,6 @@ def lower_jaxpr_to_module(
   else:
     dim_vars = ()
 
-  arg_op_shardings = (
-      map(_to_logical_op_sharding, jaxpr.in_avals, arg_shardings)
-      if arg_shardings is not None else arg_shardings)
-  result_op_shardings = (
-      map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
-      if result_shardings is not None else result_shardings)
-
   arg_layouts = (map(_to_xla_layout, in_layouts) if in_layouts is not None
                  else in_layouts)
   result_layouts = (map(_to_xla_layout, out_layouts) if out_layouts is not None
@@ -978,8 +974,8 @@ def lower_jaxpr_to_module(
         replace_tokens_with_dummy=replace_tokens_with_dummy,
         num_output_tokens=0,
         replicated_args=replicated_args,
-        arg_shardings=arg_op_shardings,
-        result_shardings=result_op_shardings,
+        arg_shardings=arg_shardings,
+        result_shardings=result_shardings,
         input_output_aliases=input_output_aliases,
         xla_donated_args=xla_donated_args,
         arg_names=arg_names,
@@ -1123,8 +1119,8 @@ def lower_jaxpr_to_fun(
     public: bool = False,
     replace_tokens_with_dummy: bool = False,
     replicated_args: Sequence[bool] | None = None,
-    arg_shardings: Sequence[xc.HloSharding | None] | None = None,
-    result_shardings: Sequence[xc.HloSharding | None] | None = None,
+    arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
+    result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
     use_sharding_annotations: bool = True,
     input_output_aliases: Sequence[int | None] | None = None,
     xla_donated_args: Sequence[bool] | None = None,
@@ -1483,15 +1479,6 @@ def wrap_with_memory_kind(
   return op.result
 
 
-def _to_physical_op_sharding(
-    aval: core.AbstractValue | None, sharding: xc.HloSharding | None
-) -> xc.OpSharding | None:
-  if (isinstance(aval, core.ShapedArray) and dtypes.issubdtype(aval.dtype, dtypes.extended)
-      and sharding is not None):
-    return aval.dtype._rules.physical_hlo_sharding(aval, sharding).to_proto()
-  return None if sharding is None else sharding.to_proto()  # type: ignore
-
-
 def _emit_lowering_rule_as_fun(lowering_rule,
                                ctx: LoweringRuleContext) -> func_dialect.FuncOp:
   """Emits the contents of a lowering rule as a private function."""
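A hypothetical usage sketch of the new helper above (not part of the commit; the mesh/aval setup is assumed): for an ordinary dtype, `_to_physical_op_sharding` is simply logical sharding -> HloSharding -> OpSharding proto; only extended dtypes take the physical_sharding / physical_aval detour.

# Assumed illustrative example, not from the commit.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src import core
from jax._src.interpreters import mlir

mesh = Mesh(np.array(jax.devices()).reshape(-1), ("x",))
aval = core.ShapedArray((8, 2), np.dtype("float32"))
proto = mlir._to_physical_op_sharding(aval, NamedSharding(mesh, P("x", None)))
# proto is an xc.OpSharding; sharding=None returns None, and an extended-dtype
# aval (e.g. a PRNG key array) is first mapped to its physical sharding and
# physical aval, exactly as in the hunk above.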
4 changes: 0 additions & 4 deletions jax/_src/lax/lax.py
@@ -5110,10 +5110,6 @@ def handler(bufs):
       return core.DArray(aval, phys_handler(bufs))
     return handler
 
-  @staticmethod
-  def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
-    return hlo_sharding
-
   @staticmethod
   def logical_sharding(aval, phys_sharding):
     return phys_sharding
6 changes: 2 additions & 4 deletions jax/_src/pjit.py
@@ -1617,10 +1617,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
 
   func = mod_ctx.cached_primitive_lowerings.get(key, None)
   if func is None:
-    arg_shardings = [None if is_unspecified(i) else i._to_xla_hlo_sharding(aval.ndim)
-                     for aval, i in zip(ctx.avals_in, in_shardings)]
-    result_shardings = [None if is_unspecified(o) else o._to_xla_hlo_sharding(aval.ndim)
-                        for aval, o in zip(ctx.avals_out, out_shardings)]
+    arg_shardings = [None if is_unspecified(i) else i for i in in_shardings]
+    result_shardings = [None if is_unspecified(o) else o for o in out_shardings]
     # TODO(b/228598865): inlined calls cannot have shardings set directly on the
     # inputs or outputs because they are lost during MLIR->HLO conversion.
     # using_sharding_annotation=False means we add an identity operation instead.
25 changes: 12 additions & 13 deletions jax/_src/prng.py
@@ -342,8 +342,7 @@ def make_key_array_phys_sharding(aval, sharding):
   else:
     hlos = sharding._to_xla_hlo_sharding(aval.ndim)
     return GSPMDSharding(
-        sharding._device_assignment,
-        KeyTyRules.physical_hlo_sharding(aval, hlos))
+        sharding._device_assignment, physical_hlo_sharding(aval, hlos))
 
 
 def get_logical_gspmd_sharding(aval, phys_sharding):
@@ -361,6 +360,17 @@ def get_logical_gspmd_sharding(aval, phys_sharding):
       xc.HloSharding.from_proto(logical_op_sharding))
 
 
+def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
+  key_shape = aval.dtype._impl.key_shape
+  new_op_sharding = hlo_sharding.to_proto().clone()  # type: ignore
+  partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
+      hlo_sharding)
+  suffix = [] if num_replicas == 1 else [num_replicas]
+  tad = partitions + [1] * len(key_shape) + suffix
+  new_op_sharding.tile_assignment_dimensions = tad
+  return xc.HloSharding.from_proto(new_op_sharding)
+
+
 class KeyTyRules:
 
   @staticmethod
@@ -382,17 +392,6 @@ def physical_element_aval(dtype) -> core.ShapedArray:
   def physical_const(val) -> Array:
     return val._base_array
 
-  @staticmethod
-  def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
-    key_shape = aval.dtype._impl.key_shape
-    new_op_sharding = hlo_sharding.to_proto().clone()  # type: ignore
-    partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
-        hlo_sharding)
-    suffix = [] if num_replicas == 1 else [num_replicas]
-    tad = partitions + [1] * len(key_shape) + suffix
-    new_op_sharding.tile_assignment_dimensions = tad
-    return xc.HloSharding.from_proto(new_op_sharding)
-
   @staticmethod
   def physical_sharding(
       aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding:
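A small worked sketch of the relocated helper (assumed setup, not part of the commit; requires at least 4 devices): a new-style PRNG key has a physical uint32 buffer with an extra trailing key_shape dimension, so physical_hlo_sharding pads the tile assignment with a 1 for each key dimension.

# Assumed example; default threefry keys have key_shape (2,).
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax._src import core, prng

mesh = Mesh(np.array(jax.devices()[:4]), ("x",))
keys = jax.random.split(jax.random.key(0), 8)       # logical shape (8,), key dtype
aval = core.ShapedArray(keys.shape, keys.dtype)
logical = NamedSharding(mesh, P("x"))._to_xla_hlo_sharding(aval.ndim)
physical = prng.physical_hlo_sharding(aval, logical)
# The 4-way split of dim 0 gains a trailing 1 for the key buffer dimension:
print(physical.to_proto().tile_assignment_dimensions)  # -> [4, 1]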
20 changes: 10 additions & 10 deletions jax/experimental/shard_map.py
@@ -567,13 +567,13 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
                aval_in, aval_out, x):
   manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
   axes = {name: i for i, ns in names.items() for name in ns}
-  shard_proto = NamedSharding(
-      mesh, sharding_impls.array_mapping_to_axis_resources(axes)  # type: ignore
-  )._to_xla_hlo_sharding(aval_in.ndim)
+  ns = NamedSharding(mesh, sharding_impls.array_mapping_to_axis_resources(axes))  # type: ignore
   if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
-    shard_proto = aval_in.dtype._rules.physical_hlo_sharding(aval_in, shard_proto)
+    ns = aval_in.dtype._rules.physical_sharding(aval_in, ns)
     aval_in = core.physical_aval(aval_in)
+  shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto()
   unspecified = set(range(aval_in.ndim)) if auto else set()
-  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto.to_proto(),  # type: ignore
+  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,  # type: ignore
                                   unspecified_dims=unspecified)
   return [mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, set())]
 
@@ -583,13 +583,13 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
   manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
   sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=set())
   axes = {name: i for i, ns in names.items() for name in ns}
-  shard_proto = NamedSharding(
-      mesh, sharding_impls.array_mapping_to_axis_resources(axes)  # type: ignore
-  )._to_xla_hlo_sharding(aval_out.ndim)
+  ns = NamedSharding(mesh, sharding_impls.array_mapping_to_axis_resources(axes))  # type: ignore
   if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
-    shard_proto = aval_out.dtype._rules.physical_hlo_sharding(aval_out, shard_proto)
+    ns = aval_out.dtype._rules.physical_sharding(aval_out, ns)
     aval_out = core.physical_aval(aval_out)
+  shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()
   unspecified = set(range(aval_out.ndim)) if auto else set()
-  return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto.to_proto(),
+  return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto,
                                          unspecified)  # type: ignore
 
 def _pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str:
9 changes: 0 additions & 9 deletions tests/lax_test.py
@@ -45,7 +45,6 @@
 from jax._src.interpreters import pxla
 from jax._src.internal_test_util import lax_test_util
 from jax._src.lax import lax as lax_internal
-from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.util import NumpyComplexWarning
 
@@ -2989,14 +2988,6 @@ class FooTyRules:
   def physical_element_aval(dtype) -> core.ShapedArray:
     return core.ShapedArray((2,), jnp.dtype('uint32'))
 
-  @staticmethod
-  def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding):
-    op_sharding_proto = hlo_sharding.to_proto()
-    new_op_sharding = op_sharding_proto.clone()
-    tad = list(new_op_sharding.tile_assignment_dimensions)
-    new_op_sharding.tile_assignment_dimensions = [*tad, 1]
-    return xc.HloSharding.from_proto(new_op_sharding)
-
   @staticmethod
   def logical_sharding(aval, phys_sharding):
     return phys_sharding
