Remove physical_hlo_sharding from TyRules. #20264

Merged 1 commit on Mar 15, 2024
35 changes: 11 additions & 24 deletions jax/_src/interpreters/mlir.py
@@ -819,16 +819,19 @@ class LoweringResult(NamedTuple):
_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]


def _to_logical_op_sharding(
def _to_physical_op_sharding(
aval: core.AbstractValue, sharding: XLACompatibleSharding | None,
) -> xc.HloSharding | None:
) -> xc.OpSharding | None:
if sharding is None:
return None
assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
if isinstance(aval, AbstractRef):
return _to_logical_op_sharding(aval.inner_aval, sharding)
return _to_physical_op_sharding(aval.inner_aval, sharding)
assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
return sharding._to_xla_hlo_sharding(aval.ndim)
if dtypes.issubdtype(aval.dtype, dtypes.extended):
sharding = aval.dtype._rules.physical_sharding(aval, sharding)
aval = core.physical_aval(aval)
return sharding._to_xla_hlo_sharding(aval.ndim).to_proto() # type: ignore


def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str | None:
@@ -941,13 +944,6 @@ def lower_jaxpr_to_module(
else:
dim_vars = ()

arg_op_shardings = (
map(_to_logical_op_sharding, jaxpr.in_avals, arg_shardings)
if arg_shardings is not None else arg_shardings)
result_op_shardings = (
map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
if result_shardings is not None else result_shardings)

arg_layouts = (map(_to_xla_layout, in_layouts) if in_layouts is not None
else in_layouts)
result_layouts = (map(_to_xla_layout, out_layouts) if out_layouts is not None
@@ -978,8 +974,8 @@
replace_tokens_with_dummy=replace_tokens_with_dummy,
num_output_tokens=0,
replicated_args=replicated_args,
arg_shardings=arg_op_shardings,
result_shardings=result_op_shardings,
arg_shardings=arg_shardings,
result_shardings=result_shardings,
input_output_aliases=input_output_aliases,
xla_donated_args=xla_donated_args,
arg_names=arg_names,
@@ -1123,8 +1119,8 @@ def lower_jaxpr_to_fun(
public: bool = False,
replace_tokens_with_dummy: bool = False,
replicated_args: Sequence[bool] | None = None,
arg_shardings: Sequence[xc.HloSharding | None] | None = None,
result_shardings: Sequence[xc.HloSharding | None] | None = None,
arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
use_sharding_annotations: bool = True,
input_output_aliases: Sequence[int | None] | None = None,
xla_donated_args: Sequence[bool] | None = None,
@@ -1483,15 +1479,6 @@ def wrap_with_memory_kind(
return op.result


def _to_physical_op_sharding(
aval: core.AbstractValue | None, sharding: xc.HloSharding | None
) -> xc.OpSharding | None:
if (isinstance(aval, core.ShapedArray) and dtypes.issubdtype(aval.dtype, dtypes.extended)
and sharding is not None):
return aval.dtype._rules.physical_hlo_sharding(aval, sharding).to_proto()
return None if sharding is None else sharding.to_proto() # type: ignore


def _emit_lowering_rule_as_fun(lowering_rule,
ctx: LoweringRuleContext) -> func_dialect.FuncOp:
"""Emits the contents of a lowering rule as a private function."""
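The separate logical (HloSharding) and physical (OpSharding) helpers are folded into this single conversion, which routes extended dtypes through their physical sharding and physical aval before building the proto. A minimal sketch of the resulting behavior, restated outside the diff and simplified to the non-Ref case (module paths follow jax._src):

from jax._src import core, dtypes

def to_physical_op_sharding(aval, sharding):
    # A missing sharding stays missing; callers treat None as unspecified.
    if sharding is None:
        return None
    # Extended dtypes (e.g. PRNG key arrays) swap in their physical sharding and
    # physical aval, so no HloSharding-level hook on TyRules is needed anymore.
    if dtypes.issubdtype(aval.dtype, dtypes.extended):
        sharding = aval.dtype._rules.physical_sharding(aval, sharding)
        aval = core.physical_aval(aval)
    # Single conversion at the end: sharding -> HloSharding -> OpSharding proto.
    return sharding._to_xla_hlo_sharding(aval.ndim).to_proto()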
4 changes: 0 additions & 4 deletions jax/_src/lax/lax.py
@@ -5110,10 +5110,6 @@ def handler(bufs):
return core.DArray(aval, phys_handler(bufs))
return handler

@staticmethod
def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
return hlo_sharding

@staticmethod
def logical_sharding(aval, phys_sharding):
return phys_sharding
6 changes: 2 additions & 4 deletions jax/_src/pjit.py
@@ -1617,10 +1617,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,

func = mod_ctx.cached_primitive_lowerings.get(key, None)
if func is None:
arg_shardings = [None if is_unspecified(i) else i._to_xla_hlo_sharding(aval.ndim)
for aval, i in zip(ctx.avals_in, in_shardings)]
result_shardings = [None if is_unspecified(o) else o._to_xla_hlo_sharding(aval.ndim)
for aval, o in zip(ctx.avals_out, out_shardings)]
arg_shardings = [None if is_unspecified(i) else i for i in in_shardings]
result_shardings = [None if is_unspecified(o) else o for o in out_shardings]
# TODO(b/228598865): inlined calls cannot have shardings set directly on the
# inputs or outputs because they are lost during MLIR->HLO conversion.
# using_sharding_annotation=False means we add an identity operation instead.
25 changes: 12 additions & 13 deletions jax/_src/prng.py
@@ -342,8 +342,7 @@ def make_key_array_phys_sharding(aval, sharding):
else:
hlos = sharding._to_xla_hlo_sharding(aval.ndim)
return GSPMDSharding(
sharding._device_assignment,
KeyTyRules.physical_hlo_sharding(aval, hlos))
sharding._device_assignment, physical_hlo_sharding(aval, hlos))


def get_logical_gspmd_sharding(aval, phys_sharding):
Expand All @@ -361,6 +360,17 @@ def get_logical_gspmd_sharding(aval, phys_sharding):
xc.HloSharding.from_proto(logical_op_sharding))


def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
key_shape = aval.dtype._impl.key_shape
new_op_sharding = hlo_sharding.to_proto().clone() # type: ignore
partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
hlo_sharding)
suffix = [] if num_replicas == 1 else [num_replicas]
tad = partitions + [1] * len(key_shape) + suffix
new_op_sharding.tile_assignment_dimensions = tad
return xc.HloSharding.from_proto(new_op_sharding)


class KeyTyRules:

@staticmethod
@@ -382,17 +392,6 @@ def physical_element_aval(dtype) -> core.ShapedArray:
def physical_const(val) -> Array:
return val._base_array

@staticmethod
def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
key_shape = aval.dtype._impl.key_shape
new_op_sharding = hlo_sharding.to_proto().clone() # type: ignore
partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
hlo_sharding)
suffix = [] if num_replicas == 1 else [num_replicas]
tad = partitions + [1] * len(key_shape) + suffix
new_op_sharding.tile_assignment_dimensions = tad
return xc.HloSharding.from_proto(new_op_sharding)

@staticmethod
def physical_sharding(
aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding:
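The key-specific tiling adjustment survives as a module-level helper in prng.py (called from make_key_array_phys_sharding above) rather than as a KeyTyRules static method. A hedged worked example of what it computes, assuming the default threefry implementation whose key_shape is (2,); mesh, P, and key_aval below are illustrative and not part of the change:

# key_aval: an (8, 8) key array; its physical representation has a trailing dim of size 2.
hlo = NamedSharding(mesh, P('x', 'y'))._to_xla_hlo_sharding(key_aval.ndim)
# On a 4x2 mesh, get_num_ways_dim_sharded(hlo) gives partitions [4, 2] and num_replicas 1,
# so the helper emits tile_assignment_dimensions [4, 2, 1]: the key dim stays replicated.
phys = physical_hlo_sharding(key_aval, hlo)
# With 2-way replication (num_replicas == 2) the result would be [4, 2, 1, 2] instead.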
20 changes: 10 additions & 10 deletions jax/experimental/shard_map.py
@@ -567,13 +567,13 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
aval_in, aval_out, x):
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
axes = {name: i for i, ns in names.items() for name in ns}
shard_proto = NamedSharding(
mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore
)._to_xla_hlo_sharding(aval_in.ndim)
ns = NamedSharding(mesh, sharding_impls.array_mapping_to_axis_resources(axes)) # type: ignore
if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
shard_proto = aval_in.dtype._rules.physical_hlo_sharding(aval_in, shard_proto)
ns = aval_in.dtype._rules.physical_sharding(aval_in, ns)
aval_in = core.physical_aval(aval_in)
shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto()
unspecified = set(range(aval_in.ndim)) if auto else set()
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto.to_proto(), # type: ignore
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto, # type: ignore
unspecified_dims=unspecified)
return [mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, set())]

@@ -583,13 +583,13 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=set())
axes = {name: i for i, ns in names.items() for name in ns}
shard_proto = NamedSharding(
mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore
)._to_xla_hlo_sharding(aval_out.ndim)
ns = NamedSharding(mesh, sharding_impls.array_mapping_to_axis_resources(axes)) # type: ignore
if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
shard_proto = aval_out.dtype._rules.physical_hlo_sharding(aval_out, shard_proto)
ns = aval_out.dtype._rules.physical_sharding(aval_out, ns)
aval_out = core.physical_aval(aval_out)
shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()
unspecified = set(range(aval_out.ndim)) if auto else set()
return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto.to_proto(),
return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto,
unspecified) # type: ignore

def _pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str:
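Both _xla_shard and _xla_unshard now mirror the mlir.py path: build a NamedSharding from the axis mapping, let an extended dtype substitute its physical sharding and physical aval, and only then lower to a proto. A sketch of that shared step as a hypothetical helper (the name and factoring are illustrative, not something this change adds):

def _axes_to_physical_proto(mesh, axes, aval):
    ns = NamedSharding(mesh, sharding_impls.array_mapping_to_axis_resources(axes))
    if dtypes.issubdtype(aval.dtype, dtypes.extended):
        ns = aval.dtype._rules.physical_sharding(aval, ns)
        aval = core.physical_aval(aval)
    # The proto must describe the rank of the physical aval.
    return ns._to_xla_hlo_sharding(aval.ndim).to_proto(), aval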
9 changes: 0 additions & 9 deletions tests/lax_test.py
@@ -45,7 +45,6 @@
from jax._src.interpreters import pxla
from jax._src.internal_test_util import lax_test_util
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.util import NumpyComplexWarning

@@ -2989,14 +2988,6 @@ class FooTyRules:
def physical_element_aval(dtype) -> core.ShapedArray:
return core.ShapedArray((2,), jnp.dtype('uint32'))

@staticmethod
def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding):
op_sharding_proto = hlo_sharding.to_proto()
new_op_sharding = op_sharding_proto.clone()
tad = list(new_op_sharding.tile_assignment_dimensions)
new_op_sharding.tile_assignment_dimensions = [*tad, 1]
return xc.HloSharding.from_proto(new_op_sharding)

@staticmethod
def logical_sharding(aval, phys_sharding):
return phys_sharding