Skip to content

Commit

Permalink
fix x64 dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
justinjfu committed Oct 19, 2024
1 parent 784c2d9 commit cce43a3
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,8 +1003,6 @@ def _to_physical_op_sharding(
if isinstance(aval, AbstractRef):
return _to_physical_op_sharding(ctx, aval.inner_aval, sharding)
assert isinstance(aval, (core.ShapedArray, core.DShapedArray))
if dtypes.issubdtype(aval.dtype, dtypes.extended):
assert False, aval
axis_ctx = ctx.axis_context
if (isinstance(axis_ctx, sharding_impls.SPMDAxisContext) and
axis_ctx.manual_axes):
Expand Down Expand Up @@ -2419,9 +2417,7 @@ def wrap_with_layout_op(ctx: LoweringRuleContext,
aval_in: core.AbstractValue):
result_type = aval_to_ir_type(aval_out)
assert isinstance(result_type, ir.Type), result_type
if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
assert False
out_shape = core.physical_aval(aval_out).shape # type: ignore
out_shape = aval_out.shape # type: ignore
if core.is_constant_shape(out_shape):
result_shapes = None
else:
Expand Down Expand Up @@ -2659,11 +2655,8 @@ def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None):
return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get())

def _aval_to_default_layouts(aval):
if dtypes.issubdtype(aval.dtype, dtypes.extended):
assert False
avals = [core.physical_aval(aval)]
# Row major order is default for `NumPy`.
return [list(range(aval.ndim - 1, -1, -1)) for aval in avals]
return [list(range(aval.ndim - 1, -1, -1))]


def emit_python_callback(
Expand Down

0 comments on commit cce43a3

Please sign in to comment.