diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 93cef598b275..142c4df222c3 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -819,16 +819,19 @@ class LoweringResult(NamedTuple):
 _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]


-def _to_logical_op_sharding(
+def _to_physical_op_sharding(
     aval: core.AbstractValue, sharding: XLACompatibleSharding | None,
-) -> xc.HloSharding | None:
+) -> xc.OpSharding | None:
   if sharding is None:
     return None
   assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
   if isinstance(aval, AbstractRef):
-    return _to_logical_op_sharding(aval.inner_aval, sharding)
+    return _to_physical_op_sharding(aval.inner_aval, sharding)
   assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
-  return sharding._to_xla_hlo_sharding(aval.ndim)
+  if dtypes.issubdtype(aval.dtype, dtypes.extended):
+    sharding = aval.dtype._rules.physical_sharding(aval, sharding)
+    aval = core.physical_aval(aval)
+  return sharding._to_xla_hlo_sharding(aval.ndim).to_proto()  # type: ignore


 def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str | None:
@@ -941,13 +944,6 @@ def lower_jaxpr_to_module(
   else:
     dim_vars = ()

-  arg_op_shardings = (
-      map(_to_logical_op_sharding, jaxpr.in_avals, arg_shardings)
-      if arg_shardings is not None else arg_shardings)
-  result_op_shardings = (
-      map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
-      if result_shardings is not None else result_shardings)
-
   arg_layouts = (map(_to_xla_layout, in_layouts) if in_layouts is not None
                  else in_layouts)
   result_layouts = (map(_to_xla_layout, out_layouts) if out_layouts is not None
@@ -978,8 +974,8 @@ def lower_jaxpr_to_module(
         replace_tokens_with_dummy=replace_tokens_with_dummy,
         num_output_tokens=0,
         replicated_args=replicated_args,
-        arg_shardings=arg_op_shardings,
-        result_shardings=result_op_shardings,
+        arg_shardings=arg_shardings,
+        result_shardings=result_shardings,
         input_output_aliases=input_output_aliases,
         xla_donated_args=xla_donated_args,
         arg_names=arg_names,
@@ -1123,8 +1119,8 @@ def lower_jaxpr_to_fun(
     public: bool = False,
     replace_tokens_with_dummy: bool = False,
     replicated_args: Sequence[bool] | None = None,
-    arg_shardings: Sequence[xc.HloSharding | None] | None = None,
-    result_shardings: Sequence[xc.HloSharding | None] | None = None,
+    arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
+    result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
     use_sharding_annotations: bool = True,
     input_output_aliases: Sequence[int | None] | None = None,
     xla_donated_args: Sequence[bool] | None = None,
@@ -1483,15 +1479,6 @@ def wrap_with_memory_kind(
   return op.result


-def _to_physical_op_sharding(
-    aval: core.AbstractValue | None, sharding: xc.HloSharding | None
-) -> xc.OpSharding | None:
-  if (isinstance(aval, core.ShapedArray) and dtypes.issubdtype(aval.dtype, dtypes.extended)
-      and sharding is not None):
-    return aval.dtype._rules.physical_hlo_sharding(aval, sharding).to_proto()
-  return None if sharding is None else sharding.to_proto()  # type: ignore
-
-
 def _emit_lowering_rule_as_fun(lowering_rule,
                                ctx: LoweringRuleContext) -> func_dialect.FuncOp:
   """Emits the contents of a lowering rule as a private function."""
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 0af178467e4a..cd8bc13263c6 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -5110,10 +5110,6 @@ def handler(bufs):
       return core.DArray(aval, phys_handler(bufs))
     return handler

-  @staticmethod
-  def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
-    return hlo_sharding
-
   @staticmethod
   def logical_sharding(aval, phys_sharding):
     return phys_sharding
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 9d07b963d01b..0d2b4d184c9e 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -1617,10 +1617,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,

   func = mod_ctx.cached_primitive_lowerings.get(key, None)
   if func is None:
-    arg_shardings = [None if is_unspecified(i) else i._to_xla_hlo_sharding(aval.ndim)
-                     for aval, i in zip(ctx.avals_in, in_shardings)]
-    result_shardings = [None if is_unspecified(o) else o._to_xla_hlo_sharding(aval.ndim)
-                        for aval, o in zip(ctx.avals_out, out_shardings)]
+    arg_shardings = [None if is_unspecified(i) else i for i in in_shardings]
+    result_shardings = [None if is_unspecified(o) else o for o in out_shardings]
     # TODO(b/228598865): inlined calls cannot have shardings set directly on the
     # inputs or outputs because they are lost during MLIR->HLO conversion.
     # using_sharding_annotation=False means we add an identity operation instead.
diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index db6d174687a3..00d3b1bf7d38 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -342,8 +342,7 @@ def make_key_array_phys_sharding(aval, sharding):
   else:
     hlos = sharding._to_xla_hlo_sharding(aval.ndim)
     return GSPMDSharding(
-        sharding._device_assignment,
-        KeyTyRules.physical_hlo_sharding(aval, hlos))
+        sharding._device_assignment, physical_hlo_sharding(aval, hlos))


 def get_logical_gspmd_sharding(aval, phys_sharding):
@@ -361,6 +360,17 @@ def get_logical_gspmd_sharding(aval, phys_sharding):
       xc.HloSharding.from_proto(logical_op_sharding))


+def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
+  key_shape = aval.dtype._impl.key_shape
+  new_op_sharding = hlo_sharding.to_proto().clone()  # type: ignore
+  partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
+      hlo_sharding)
+  suffix = [] if num_replicas == 1 else [num_replicas]
+  tad = partitions + [1] * len(key_shape) + suffix
+  new_op_sharding.tile_assignment_dimensions = tad
+  return xc.HloSharding.from_proto(new_op_sharding)
+
+
 class KeyTyRules:

   @staticmethod
@@ -382,17 +392,6 @@ def physical_element_aval(dtype) -> core.ShapedArray:
   def physical_const(val) -> Array:
     return val._base_array

-  @staticmethod
-  def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
-    key_shape = aval.dtype._impl.key_shape
-    new_op_sharding = hlo_sharding.to_proto().clone()  # type: ignore
-    partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
-        hlo_sharding)
-    suffix = [] if num_replicas == 1 else [num_replicas]
-    tad = partitions + [1] * len(key_shape) + suffix
-    new_op_sharding.tile_assignment_dimensions = tad
-    return xc.HloSharding.from_proto(new_op_sharding)
-
   @staticmethod
   def physical_sharding(
       aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding:
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index f33aa89382ea..f2ae274a6570 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -567,13 +567,13 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
                aval_in, aval_out, x):
   manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
   axes = {name: i for i, ns in names.items() for name in ns}
-  shard_proto = NamedSharding(
-      mesh, sharding_impls.array_mapping_to_axis_resources(axes)  # type: ignore
-  )._to_xla_hlo_sharding(aval_in.ndim)
+  ns = NamedSharding(mesh, sharding_impls.array_mapping_to_axis_resources(axes))  # type: ignore
   if dtypes.issubdtype(aval_in.dtype, dtypes.extended):
-    shard_proto = aval_in.dtype._rules.physical_hlo_sharding(aval_in, shard_proto)
+    ns = aval_in.dtype._rules.physical_sharding(aval_in, ns)
+    aval_in = core.physical_aval(aval_in)
+  shard_proto = ns._to_xla_hlo_sharding(aval_in.ndim).to_proto()
   unspecified = set(range(aval_in.ndim)) if auto else set()
-  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto.to_proto(),  # type: ignore
+  sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, shard_proto,  # type: ignore
                                   unspecified_dims=unspecified)
   return [mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, set())]

@@ -583,13 +583,13 @@ def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
   manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names) - auto, mesh)
   sx = mlir.wrap_with_sharding_op(ctx, x, aval_in, manual_proto, unspecified_dims=set())
   axes = {name: i for i, ns in names.items() for name in ns}
-  shard_proto = NamedSharding(
-      mesh, sharding_impls.array_mapping_to_axis_resources(axes)  # type: ignore
-  )._to_xla_hlo_sharding(aval_out.ndim)
+  ns = NamedSharding(mesh, sharding_impls.array_mapping_to_axis_resources(axes))  # type: ignore
   if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
-    shard_proto = aval_out.dtype._rules.physical_hlo_sharding(aval_out, shard_proto)
+    ns = aval_out.dtype._rules.physical_sharding(aval_out, ns)
+    aval_out = core.physical_aval(aval_out)
+  shard_proto = ns._to_xla_hlo_sharding(aval_out.ndim).to_proto()
   unspecified = set(range(aval_out.ndim)) if auto else set()
-  return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto.to_proto(),
+  return mlir.wrap_with_shard_to_full_op(ctx, sx, aval_out, shard_proto,
                                          unspecified)  # type: ignore

 def _pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str:
diff --git a/tests/lax_test.py b/tests/lax_test.py
index 74a36e6e2f00..68bda5a5394c 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -45,7 +45,6 @@
 from jax._src.interpreters import pxla
 from jax._src.internal_test_util import lax_test_util
 from jax._src.lax import lax as lax_internal
-from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.util import NumpyComplexWarning

@@ -2989,14 +2988,6 @@ class FooTyRules:
   def physical_element_aval(dtype) -> core.ShapedArray:
     return core.ShapedArray((2,), jnp.dtype('uint32'))

-  @staticmethod
-  def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding):
-    op_sharding_proto = hlo_sharding.to_proto()
-    new_op_sharding = op_sharding_proto.clone()
-    tad = list(new_op_sharding.tile_assignment_dimensions)
-    new_op_sharding.tile_assignment_dimensions = [*tad, 1]
-    return xc.HloSharding.from_proto(new_op_sharding)
-
   @staticmethod
   def logical_sharding(aval, phys_sharding):
     return phys_sharding
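The sketch below is not part of the diff; it is only an illustrative check of the behavior the refactor preserves. With this change, the logical XLACompatibleSharding of an extended-dtype value (such as a new-style PRNG key array) is carried through lowering and converted to a physical OpSharding, with trailing dimensions for the key payload, only at the point where the HLO annotation is emitted. The snippet uses only public JAX APIs; with more than one device, the lowered text should show sharding annotations over the physical (trailing-dim) key shape. Rerunning it before and after the change should produce the same annotations.

# Illustrative sketch (assumed example, not from the change itself): lower a
# jit-compiled function over a sharded new-style key array and inspect the
# emitted sharding annotations.
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('x',))           # 1-D mesh over all local devices
s = NamedSharding(mesh, P('x'))              # shard the logical key axis

keys = jax.random.split(jax.random.key(0), jax.device_count())

# jax.random.key_data exposes the physical uint32 payload of the keys.
lowered = jax.jit(jax.random.key_data, in_shardings=s).lower(keys)
# With >1 device, the sharding annotations in the HLO text reflect the
# physical key shape (an extra trailing dimension of size 1 in the tiling).
print(lowered.as_text())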