From cce43a37a25272e97513ef3be2233eac79e0f28e Mon Sep 17 00:00:00 2001 From: Justin Fu Date: Fri, 18 Oct 2024 15:45:41 -0700 Subject: [PATCH] fix x64 dtypes --- jax/_src/interpreters/mlir.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 2f10484bb380..0834692e88a2 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1003,8 +1003,6 @@ def _to_physical_op_sharding( if isinstance(aval, AbstractRef): return _to_physical_op_sharding(ctx, aval.inner_aval, sharding) assert isinstance(aval, (core.ShapedArray, core.DShapedArray)) - if dtypes.issubdtype(aval.dtype, dtypes.extended): - assert False, aval axis_ctx = ctx.axis_context if (isinstance(axis_ctx, sharding_impls.SPMDAxisContext) and axis_ctx.manual_axes): @@ -2419,9 +2417,7 @@ def wrap_with_layout_op(ctx: LoweringRuleContext, aval_in: core.AbstractValue): result_type = aval_to_ir_type(aval_out) assert isinstance(result_type, ir.Type), result_type - if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - assert False - out_shape = core.physical_aval(aval_out).shape # type: ignore + out_shape = aval_out.shape # type: ignore if core.is_constant_shape(out_shape): result_shapes = None else: @@ -2659,11 +2655,8 @@ def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None): return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get()) def _aval_to_default_layouts(aval): - if dtypes.issubdtype(aval.dtype, dtypes.extended): - assert False - avals = [core.physical_aval(aval)] # Row major order is default for `NumPy`. - return [list(range(aval.ndim - 1, -1, -1)) for aval in avals] + return [list(range(aval.ndim - 1, -1, -1))] def emit_python_callback(