diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index b62110dff317..e775637d81e8 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -240,20 +240,13 @@ def ir_constant(x, mlir_type=None): if not mlir_type: mlir_type = _dtype_to_ir_type(x.dtype) if isinstance(x, int) or x.dtype in (np.int32, np.uint32, np.int8): - return arith.ConstantOp(mlir_type, ir.IntegerAttr.get(mlir_type, int(x)) - ).result + return arith.constant(mlir_type, ir.IntegerAttr.get(mlir_type, int(x))) elif isinstance(x, float) or x.dtype == np.float32: - return arith.ConstantOp( - mlir_type, ir.FloatAttr.get(mlir_type, float(x)) - ).result + return arith.constant(mlir_type, ir.FloatAttr.get(mlir_type, float(x))) elif x.dtype == jnp.bfloat16: - return arith.ConstantOp( - mlir_type, ir.FloatAttr.get(mlir_type, float(x)) - ).result + return arith.constant(mlir_type, ir.FloatAttr.get(mlir_type, float(x))) elif x.dtype == jnp.bool_: - return arith.ConstantOp( - mlir_type, ir.BoolAttr.get(bool(x)) - ).result + return arith.constant(mlir_type, ir.BoolAttr.get(bool(x))) raise NotImplementedError(x.dtype) @@ -942,7 +935,7 @@ def _make_index(s): return ir_constant(s, ir.IndexType.get()) if s.type == ir.IndexType.get(): return s - return arith.IndexCastOp(ir.IndexType.get(), s).result + return arith.index_cast(ir.IndexType.get(), s) def _maybe_cast_to_index(cast_to_index, x): @@ -1043,7 +1036,7 @@ def _slice_memref( _dtype_to_ir_type(ref_dtype), memory_space=ref.type.memory_space, ) - out = tpu.MemRefSliceOp(target_ref_ty, ref, starts, dynamic_sizes).result + out = tpu.memref_slice(target_ref_ty, ref, starts, dynamic_sizes) if any(squeeze_dims): # We need to squeeze out some dimensions static_sizes = tuple(s if not isinstance(s, ir.Value) @@ -1053,7 +1046,7 @@ def _slice_memref( _dtype_to_ir_type(ref_dtype), memory_space=ref.type.memory_space, ) - out = tpu.MemRefSqueezeOp(squeezed_ref_ty, out).result + out = tpu.memref_squeeze(squeezed_ref_ty, out) return out, ref_block_shape @@ -1210,8 +1203,7 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): if is_smem_load: if ctx.avals_out[0].shape: raise ValueError("Can only load scalars from SMEM") - return _maybe_cast_load_to_bool( - aval_out, memref.LoadOp(ref, starts).result) + return _maybe_cast_load_to_bool(aval_out, memref.load(ref, starts)) elif str(ref_type.memory_space) != "#tpu.memory_space": extra = "" if str(ref_type.memory_space) == "#tpu.memory_space": @@ -1221,17 +1213,17 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): ) load_aval = jax_core.ShapedArray(sizes, dtype=aval_out.dtype) if need_stride: - load_val = tpu.StridedLoadOp( + load_val = tpu.strided_load( aval_to_ir_type(load_aval, is_kernel_boundary=True), ref, starts, strides - ).result + ) else: - load_val = vector.LoadOp( - aval_to_ir_type(load_aval, is_kernel_boundary=True), ref, starts).result + load_val = vector.load( + aval_to_ir_type(load_aval, is_kernel_boundary=True), ref, starts) if load_aval != aval_out: vec_type = ir.VectorType.get(aval_out.shape, _dtype_to_ir_type(aval_out.dtype, is_kernel_boundary=True)) - load_val = vector.ShapeCastOp(vec_type, load_val).result + load_val = vector.shape_cast(vec_type, load_val) return _maybe_cast_load_to_bool(aval_out, load_val) def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree) -> KeyScalarBundle: @@ -1262,7 +1254,7 @@ def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree ref_block_shape, cast_to_index=True, ) - load_ops.append(memref.LoadOp(ref, starts).result) + load_ops.append(memref.load(ref, starts)) return KeyScalarBundle(scalars=load_ops, key_shape=tuple(ref_block_shape)) @@ -1296,10 +1288,10 @@ def _maybe_cast_load_to_bool( load_vector_type, ir.DenseElementsAttr.get_splat(load_vector_type, const_zero) ) - return arith.CmpIOp(predicate, val, vector_zeros).result + return arith.cmpi(predicate, val, vector_zeros) else: # Scalar case. const_zero = arith.ConstantOp(load_scalar_type, const_zero) - return arith.CmpIOp(predicate, val, const_zero).result + return arith.cmpi(predicate, val, const_zero) def _maybe_cast_store_to_memref_type( @@ -1308,7 +1300,7 @@ def _maybe_cast_store_to_memref_type( if expected_aval.dtype != jnp.bool_: return val int_out_type = aval_to_ir_type(expected_aval, is_kernel_boundary=True) - return arith.ExtUIOp(int_out_type, val).result + return arith.extui(int_out_type, val) def _masked_swap_lowering_rule( @@ -1353,7 +1345,7 @@ def _masked_swap_lowering_rule( if is_smem_store: if val_aval.shape: raise ValueError("Can only store scalars to SMEM") - result = memref.LoadOp(ref, starts).result + result = memref.load(ref, starts) result = _maybe_cast_load_to_bool(val_aval, result) val = _maybe_cast_store_to_memref_type(val_aval, val) memref.StoreOp(val, ref, starts) @@ -1384,18 +1376,18 @@ def _masked_swap_lowering_rule( mem_aval_vec_type = ir.VectorType.get(mem_aval.shape, _dtype_to_ir_type(mem_aval.dtype, is_kernel_boundary=True)) if need_stride: - result = tpu.StridedLoadOp(mem_aval_vec_type, ref, starts, strides).result + result = tpu.strided_load(mem_aval_vec_type, ref, starts, strides) else: - result = vector.LoadOp(mem_aval_vec_type, ref, starts).result + result = vector.load(mem_aval_vec_type, ref, starts) val = _maybe_cast_store_to_memref_type(val_aval, val) if mem_aval != aval_out: # We are slicing a scalar so provided dummy 1 indices result_vec_type = ir.VectorType.get(aval_out.shape, _dtype_to_ir_type(aval_out.dtype, is_kernel_boundary=True)) - result = vector.ShapeCastOp(result_vec_type, result).result + result = vector.shape_cast(result_vec_type, result) val_vec_type = ir.VectorType.get(mem_aval.shape, _dtype_to_ir_type(mem_aval.dtype, is_kernel_boundary=True)) - val = vector.ShapeCastOp(val_vec_type, val).result + val = vector.shape_cast(val_vec_type, val) result = _maybe_cast_load_to_bool(val_aval, result) if need_stride: @@ -1452,13 +1444,7 @@ def _proxy_fun(val, *, axes): out_type = aval_to_ir_type(ctx.avals_out[0]) identity = ir.DenseElementsAttr.get_splat(out_type, val) acc = arith.ConstantOp(out_type, identity) - op = vector.MultiDimReductionOp( - kind, - x, - acc, - axes, - ) - return op.result + return vector.multi_reduction(kind, x, acc, axes) return _lowering_rule @@ -1562,13 +1548,13 @@ def _proxy_fun(val, *, shape, broadcast_dimensions): out_type = ir.VectorType.get( out_shape, _dtype_to_ir_type(aval_out.dtype) ) - val = vector.ShapeCastOp(out_type, val).result + val = vector.shape_cast(out_type, val) if out_shape == aval_out.shape: return val out_type = ir.VectorType.get( aval_out.shape, _dtype_to_ir_type(aval_out.dtype) ) - return vector.BroadcastOp(out_type, val).result + return vector.broadcast(out_type, val) lowering_rules[lax.broadcast_in_dim_p] = _broadcast_in_dim_lowering_rule @@ -1623,7 +1609,7 @@ def _dot_general_lowering_rule( acc, [1] ) - return vector.ShapeCastOp(out_type, red).result + return vector.shape_cast(out_type, red) if lhs_dims == (1,): transpose_lhs = False @@ -1652,12 +1638,11 @@ def _dot_general_lowering_rule( out_tile = arith.ConstantOp( out_type, ir.DenseElementsAttr.get_splat(out_type, val) ) - op = tpu.MatmulOp( + return tpu.matmul( out_type, x, y, out_tile, transpose_lhs=transpose_lhs, transpose_rhs=transpose_rhs, precision=precision_attr ) - return op.result lowering_rules[lax.dot_general_p] = _dot_general_lowering_rule @@ -1717,27 +1702,27 @@ def _convert_element_type_lowering_rule( new_dtype, jnp.floating ): if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4: - return arith.ExtFOp(out_type, x).result + return arith.extf(out_type, x) elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4: - return arith.TruncFOp(out_type, x).result + return arith.truncf(out_type, x) elif jnp.issubdtype(old_dtype, jnp.integer) and jnp.issubdtype( new_dtype, jnp.integer ): if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4: - return arith.ExtSIOp(out_type, x).result + return arith.extsi(out_type, x) elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4: - return arith.TruncIOp(out_type, x).result + return arith.trunci(out_type, x) elif jnp.iinfo(old_dtype).bits == jnp.iinfo(new_dtype).bits: # This case triggers when casting signed to unsigned or vice versa. return x elif jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype( new_dtype, jnp.signedinteger ) and old_dtype.itemsize == new_dtype.itemsize == 4: - return arith.FPToSIOp(out_type, x).result + return arith.fptosi(out_type, x) elif jnp.issubdtype(old_dtype, jnp.integer) and jnp.issubdtype( new_dtype, jnp.floating ) and old_dtype.itemsize == new_dtype.itemsize == 4: - return arith.SIToFPOp(out_type, x).result + return arith.sitofp(out_type, x) elif ( old_dtype == jnp.bool_ and jnp.issubdtype(new_dtype, jnp.integer) @@ -1784,8 +1769,8 @@ def _reshape_lowering_rule(ctx: LoweringRuleContext, x, new_sizes, dimensions): if any(d is None for d in new_sizes): raise NotImplementedError if not ctx.avals_in[0].shape: - return vector.BroadcastOp(aval_to_ir_type(ctx.avals_out[0]), x).result - return vector.ShapeCastOp(aval_to_ir_type(ctx.avals_out[0]), x).result + return vector.broadcast(aval_to_ir_type(ctx.avals_out[0]), x) + return vector.shape_cast(aval_to_ir_type(ctx.avals_out[0]), x) lowering_rules[lax.reshape_p] = _reshape_lowering_rule @@ -1802,17 +1787,16 @@ def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions): f" but got: {aval_out.dtype}. Try casting the input before squeezing" " the scalar." ) - return vector.ExtractOp(x, [], [0] * len(aval_in.shape)).result - return vector.ShapeCastOp(aval_to_ir_type(ctx.avals_out[0]), x).result + return vector.extract(x, [], [0] * len(aval_in.shape)) + return vector.shape_cast(aval_to_ir_type(ctx.avals_out[0]), x) lowering_rules[lax.squeeze_p] = _squeeze_lowering_rule def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension): - return tpu.ConcatenateOp( - aval_to_ir_type(ctx.avals_out[0]), xs, dimension=dimension - ).result + out_type = aval_to_ir_type(ctx.avals_out[0]) + return tpu.concatenate(out_type, xs, dimension=dimension) lowering_rules[lax.concatenate_p] = _concatenate_lowering_rule @@ -1821,7 +1805,7 @@ def _concatenate_lowering_rule(ctx: LoweringRuleContext, *xs, dimension): def _iota_lowering_rule(ctx: LoweringRuleContext, dtype, shape, dimension, sharding): out_type = aval_to_ir_type(ctx.avals_out[0]) - return tpu.IotaOp(out_type, dimension=dimension).result + return tpu.iota(out_type, dimension=dimension) lowering_rules[lax.iota_p] = _iota_lowering_rule @@ -1831,7 +1815,7 @@ def _transpose_lowering_rule(ctx: LoweringRuleContext, x, *, permutation): if permutation != (1, 0): raise NotImplementedError out_type = aval_to_ir_type(ctx.avals_out[0]) - return vector.TransposeOp(out_type, x, permutation).result + return vector.transpose(out_type, x, permutation) lowering_rules[lax.transpose_p] = _transpose_lowering_rule @@ -1870,9 +1854,9 @@ def _add_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out if jnp.issubdtype(aval_out.dtype, jnp.integer): - return arith.AddIOp(x, y).result + return arith.addi(x, y) if jnp.issubdtype(aval_out.dtype, jnp.floating): - return arith.AddFOp(x, y).result + return arith.addf(x, y) raise NotImplementedError(aval_out.dtype) @@ -1886,11 +1870,11 @@ def _max_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out if jnp.issubdtype(aval_out.dtype, jnp.signedinteger): - return arith.MaxSIOp(x, y).result + return arith.maxsi(x, y) elif jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger): - return arith.MaxUIOp(x, y).result + return arith.maxui(x, y) elif jnp.issubdtype(aval_out.dtype, jnp.floating): - return arith.MaximumFOp(x, y).result + return arith.maximumf(x, y) raise NotImplementedError(aval_out.dtype) @@ -1902,11 +1886,11 @@ def _min_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out if jnp.issubdtype(aval_out.dtype, jnp.signedinteger): - return arith.MinSIOp(x, y).result + return arith.minsi(x, y) elif jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger): - return arith.MinUIOp(x, y).result + return arith.minui(x, y) elif jnp.issubdtype(aval_out.dtype, jnp.floating): - return arith.MinimumFOp(x, y).result + return arith.minimumf(x, y) raise NotImplementedError(aval_out.dtype) @@ -1918,9 +1902,9 @@ def _sub_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out if jnp.issubdtype(aval_out.dtype, jnp.integer): - return arith.SubIOp(x, y).result + return arith.subi(x, y) if jnp.issubdtype(aval_out.dtype, jnp.floating): - return arith.SubFOp(x, y).result + return arith.subf(x, y) raise NotImplementedError(aval_out.dtype) @@ -1932,9 +1916,9 @@ def _mul_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out if jnp.issubdtype(aval_out.dtype, jnp.integer): - return arith.MulIOp(x, y).result + return arith.muli(x, y) if jnp.issubdtype(aval_out.dtype, jnp.floating): - return arith.MulFOp(x, y).result + return arith.mulf(x, y) raise NotImplementedError(aval_out.dtype) @@ -1946,11 +1930,11 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out if jnp.issubdtype(aval_out.dtype, jnp.integer): - return arith.DivSIOp(x, y).result + return arith.divsi(x, y) if jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger): - return arith.DivUIOp(x, y).result + return arith.divui(x, y) elif jnp.issubdtype(aval_out.dtype, jnp.floating): - return arith.DivFOp(x, y).result + return arith.divf(x, y) raise NotImplementedError(aval_out.dtype) @@ -1962,11 +1946,11 @@ def _rem_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) (aval_out,) = ctx.avals_out if jnp.issubdtype(aval_out.dtype, jnp.integer): - return arith.RemSIOp(x, y).result + return arith.remsi(x, y) if jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger): - return arith.RemUIOp(x, y).result + return arith.remui(x, y) elif jnp.issubdtype(aval_out.dtype, jnp.floating): - return arith.RemFOp(x, y).result + return arith.remf(x, y) raise NotImplementedError(aval_out.dtype) @@ -1977,9 +1961,9 @@ def _rem_lowering_rule(ctx: LoweringRuleContext, x, y): def _abs_lowering_rule(ctx: LoweringRuleContext, x): (aval_out,) = ctx.avals_out if jnp.issubdtype(aval_out.dtype, jnp.integer): - return math.AbsIOp(x).result + return math.absi(x) if jnp.issubdtype(aval_out.dtype, jnp.floating): - return math.AbsFOp(x).result + return math.absf(x) raise NotImplementedError(aval_out.dtype) @@ -2009,21 +1993,21 @@ def _sign_lowering_rule(ctx: LoweringRuleContext, x): def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): - return math.RsqrtOp(x).result + return math.rsqrt(x) lowering_rules[lax.rsqrt_p] = _rsqrt_lowering_rule def _sqrt_lowering_rule(ctx: LoweringRuleContext, x): - return math.SqrtOp(x).result + return math.sqrt(x) lowering_rules[lax.sqrt_p] = _sqrt_lowering_rule def _exp_lowering_rule(ctx: LoweringRuleContext, x): - return math.ExpOp(x).result + return math.exp(x) lowering_rules[lax.exp_p] = _exp_lowering_rule @@ -2031,9 +2015,9 @@ def _exp_lowering_rule(ctx: LoweringRuleContext, x): def _pow_lowering_rule(ctx: LoweringRuleContext, x, y): if not isinstance(x, ir.Value) and x == 2.: - return math.Exp2Op(y).result + return math.exp2(y) x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0]) - return math.PowFOp(x, y).result + return math.powf(x, y) lowering_rules[lax.pow_p] = _pow_lowering_rule @@ -2060,58 +2044,58 @@ def _exp2_lowering_rule(ctx: LoweringRuleContext, x): def _logistic_lowering_rule(ctx: LoweringRuleContext, x): - neg_x = arith.NegFOp(x).result - exp_neg_x = math.ExpOp(neg_x).result + neg_x = arith.negf(x) + exp_neg_x = math.exp(neg_x) aval_out = ctx.avals_out[0] out_type = aval_to_ir_type(aval_out) if aval_out.shape == (): one = ir_constant(1.0, mlir_type=out_type) else: one = vector.BroadcastOp(out_type, ir_constant(1.0)) - denom = arith.AddFOp(one, exp_neg_x).result - return arith.DivFOp(one, denom).result + denom = arith.addf(one, exp_neg_x) + return arith.divf(one, denom) lowering_rules[lax.logistic_p] = _logistic_lowering_rule def _sin_lowering_rule(ctx: LoweringRuleContext, x): - return math.SinOp(x).result + return math.sin(x) lowering_rules[lax.sin_p] = _sin_lowering_rule def _cos_lowering_rule(ctx: LoweringRuleContext, x): - return math.CosOp(x).result + return math.cos(x) lowering_rules[lax.cos_p] = _cos_lowering_rule def _tan_lowering_rule(ctx: LoweringRuleContext, x): - return math.TanOp(x).result + return math.tan(x) lowering_rules[lax.tan_p] = _tan_lowering_rule def _tanh_lowering_rule(ctx: LoweringRuleContext, x): - return math.TanhOp(x).result + return math.tanh(x) lowering_rules[lax.tanh_p] = _tanh_lowering_rule def _log_lowering_rule(ctx: LoweringRuleContext, x): - return math.LogOp(x).result + return math.log(x) lowering_rules[lax.log_p] = _log_lowering_rule def _log1p_lowering_rule(ctx: LoweringRuleContext, x): - return math.Log1pOp(x).result + return math.log1p(x) lowering_rules[lax.log1p_p] = _log1p_lowering_rule @@ -2119,9 +2103,9 @@ def _log1p_lowering_rule(ctx: LoweringRuleContext, x): def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method): if rounding_method == 0: - return math.RoundOp(x).result + return math.round(x) elif rounding_method == 1: - return math.RoundEvenOp(x).result + return math.roundeven(x) else: raise NotImplementedError(f"Unsupported rounding method: {rounding_method}") @@ -2130,21 +2114,21 @@ def _round_lowering_rule(ctx: LoweringRuleContext, x, *, rounding_method): def _ceil_lowering_rule(ctx: LoweringRuleContext, x): - return math.CeilOp(x).result + return math.ceil(x) lowering_rules[lax.ceil_p] = _ceil_lowering_rule def _floor_lowering_rule(ctx: LoweringRuleContext, x): - return math.FloorOp(x).result + return math.floor(x) lowering_rules[lax.floor_p] = _floor_lowering_rule def _clz_lowering_rule(ctx: LoweringRuleContext, x): - return math.CountLeadingZerosOp(x).result + return math.ctlz(x) lowering_rules[lax.clz_p] = _clz_lowering_rule @@ -2153,7 +2137,7 @@ def _population_count_lowering_rule(ctx: LoweringRuleContext, x): aval_out = ctx.avals_out[0] if aval_out.shape == (): raise ValueError("Population count is not supported on scalars") - return math.CtPopOp(x).result + return math.ctpop(x) lowering_rules[lax.population_count_p] = _population_count_lowering_rule @@ -2268,7 +2252,7 @@ def _cmp_lowering_rule(prim, ctx: LoweringRuleContext, x, y): def _and_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) - return arith.AndIOp(x, y).result + return arith.andi(x, y) lowering_rules[lax.and_p] = _and_lowering_rule @@ -2277,7 +2261,7 @@ def _and_lowering_rule(ctx: LoweringRuleContext, x, y): def _or_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) - return arith.OrIOp(x, y).result + return arith.ori(x, y) lowering_rules[lax.or_p] = _or_lowering_rule @@ -2303,7 +2287,7 @@ def _not_lowering_rule(ctx: LoweringRuleContext, x): minus_one = arith.ConstantOp( out_type, ir.DenseElementsAttr.get_splat(out_type, scalar_minus_one) ) - return arith.XOrIOp(x, minus_one).result + return arith.xori(x, minus_one) lowering_rules[lax.not_p] = _not_lowering_rule @@ -2324,7 +2308,7 @@ def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, x, *args): return x # Assume x and y, which we check above. y, = args - return arith.SelectOp(pred, y, x).result + return arith.select(pred, y, x) lowering_rules[lax.select_n_p] = _select_n_lowering_rule @@ -2563,9 +2547,9 @@ def _while_lowering_rule( def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches): index, *args = args out_types = map(aval_to_ir_type, ctx.avals_out) - pred = arith.CmpIOp( + pred = arith.cmpi( arith.CmpIPredicate.ne, index, ir_constant(0, index.type) - ).result + ) if_op = scf.IfOp(pred, out_types, hasElse=True) lowering_context = ctx.lowering_context.replace( block_shapes=ctx.block_shapes[1:], @@ -2576,7 +2560,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, *args, branches): if len(branches) > 2: out = _cond_lowering_rule( ctx, - arith.SubIOp(index, ir_constant(1, index.type)).result, + arith.subi(index, ir_constant(1, index.type)), *args, branches=branches[1:], ) @@ -2662,7 +2646,7 @@ def _num_programs_lowering_rule(ctx: LoweringRuleContext, *, axis: int): def _repeat_lowering_rule(ctx: LoweringRuleContext, x, *, repeats, axis): (out_aval,) = ctx.avals_out - return tpu.RepeatOp(aval_to_ir_type(out_aval), x, axis, repeats).result + return tpu.repeat(aval_to_ir_type(out_aval), x, axis, repeats) lowering_rules[tpu_primitives.repeat_p] = _repeat_lowering_rule @@ -2672,14 +2656,14 @@ def _roll_lowering_rule( ctx: LoweringRuleContext, x, shift, *, axis, stride, stride_axis ): (out_aval,) = ctx.avals_out - return tpu.DynamicRotateOp( + return tpu.dynamic_rotate( aval_to_ir_type(out_aval), x, shift, axis, stride=stride, stride_dimension=stride_axis, - ).result + ) lowering_rules[tpu_primitives.roll_p] = _roll_lowering_rule @@ -2690,13 +2674,13 @@ def _slice_lowering_rule( ): """Lowers a slice to vector dialect.""" (aval_out,) = ctx.avals_out + out_type = aval_to_ir_type(aval_out) if strides is None: strides = [1] * len(start_indices) sizes = np.array(limit_indices) - np.array(start_indices) - op = vector.ExtractStridedSliceOp( - aval_to_ir_type(aval_out), x, start_indices, sizes, strides + return vector.extract_strided_slice( + out_type, x, start_indices, sizes, strides ) - return op.result lowering_rules[lax.slice_p] = _slice_lowering_rule @@ -2704,7 +2688,7 @@ def _slice_lowering_rule( def _xor_lowering_rule(ctx: LoweringRuleContext, x, y): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) - return arith.XOrIOp(x, y).result + return arith.xori(x, y) lowering_rules[lax.xor_p] = _xor_lowering_rule @@ -2713,7 +2697,7 @@ def _xor_lowering_rule(ctx: LoweringRuleContext, x, y): def _shift_left_lowering_rule(ctx: LoweringRuleContext, x, d): x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out) - return arith.ShLIOp(x, d).result + return arith.shli(x, d) lowering_rules[lax.shift_left_p] = _shift_left_lowering_rule @@ -2722,7 +2706,7 @@ def _shift_left_lowering_rule(ctx: LoweringRuleContext, x, d): def _shift_right_arithmetic_lowering_rule(ctx: LoweringRuleContext, x, d): x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out) - return arith.ShRSIOp(x, d).result + return arith.shrsi(x, d) lowering_rules[lax.shift_right_arithmetic_p] = _shift_right_arithmetic_lowering_rule @@ -2731,7 +2715,7 @@ def _shift_right_arithmetic_lowering_rule(ctx: LoweringRuleContext, x, d): def _shift_right_logical_lowering_rules(ctx: LoweringRuleContext, x, d): x, d = _bcast(x, d, *ctx.avals_in, *ctx.avals_out) - return arith.ShRUIOp(x, d).result + return arith.shrui(x, d) lowering_rules[lax.shift_right_logical_p] = _shift_right_logical_lowering_rules @@ -2750,7 +2734,7 @@ def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x): def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty): del ty (out_aval,) = ctx.avals_out - return tpu.BitcastOp(aval_to_ir_type(out_aval), x).result + return tpu.bitcast(aval_to_ir_type(out_aval), x) lowering_rules[tpu_primitives.bitcast_p] = _bitcast_lowering_rule @@ -2762,7 +2746,7 @@ def _bitcast_convert_type_lowering_rule( new_bitwidth = pallas_utils.dtype_bitwidth(new_dtype) if old_bitwidth != new_bitwidth: raise NotImplementedError("Changing bitwidths not supported.") - return tpu.BitcastOp(aval_to_ir_type(out_aval), x).result + return tpu.bitcast(aval_to_ir_type(out_aval), x) lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value: @@ -2771,16 +2755,16 @@ def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value: if jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype): assert aval.memory_space == TPUMemorySpace.SEMAPHORE memref_type = aval_to_ir_type(aval, memory_space=TPUMemorySpace.SEMAPHORE) - return tpu.AllocaSemaphoreOp(memref_type).result + return tpu.sem_alloc(memref_type) else: out_type = ir.MemRefType.get( aval.shape, _dtype_to_ir_type(aval.dtype, is_kernel_boundary=True), memory_space=memspace) - return memref.AllocaOp(out_type, [], []).result + return memref.alloca(out_type, [], []) elif isinstance(aval, tpu_core.AbstractSemaphore): memref_type = aval_to_ir_type(aval, memory_space=TPUMemorySpace.SEMAPHORE) - return tpu.AllocaSemaphoreOp(memref_type).result + return tpu.sem_alloc(memref_type) raise NotImplementedError(f"Cannot allocate {type(aval)}.") @@ -2834,7 +2818,7 @@ def _semaphore_read_lowering_rule( sem_aval, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) sem, transforms = tree_util.tree_unflatten(args_tree, args) sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) - return tpu.SemaphoreReadOp(sem).result + return tpu.sem_read(sem) lowering_rules[tpu_primitives.semaphore_read_p] = _semaphore_read_lowering_rule @@ -2852,9 +2836,8 @@ def _semaphore_signal_lowering_rule( sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) - return tpu.SemaphoreSignalOp( - sem, value, device_id=device_id, core_id=core_index - ).results + tpu.sem_signal(sem, value, device_id=device_id, core_id=core_index) + return [] lowering_rules[tpu_primitives.semaphore_signal_p] = ( @@ -2865,7 +2848,8 @@ def _semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): sem_aval, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) sem, transforms, value = tree_util.tree_unflatten(args_tree, args) sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) - return tpu.SemaphoreWaitOp(sem, value).results + tpu.sem_wait(sem, value) + return [] lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, @@ -2901,8 +2885,9 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) - return tpu.EnqueueDMAOp(src_ref, dst_ref, sem, source_semaphore=src_sem, - device_id=device_id).results + tpu.enqueue_dma(src_ref, dst_ref, sem, source_semaphore=src_sem, + device_id=device_id) + return [] lowering_rules[tpu_primitives.dma_start_p] = _dma_start_lowering_rule @@ -2917,11 +2902,12 @@ def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, ref_block_shape = block_shapes[2] ref, _ = _transform_ref(ref, ref_aval.dtype, ref_block_shape, transforms) sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) - return tpu.WaitDMAOp(sem, ref).results + tpu.wait_dma(sem, ref) + return [] lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule def _device_id_lowering_rule(ctx: LoweringRuleContext): - return tpu.DeviceIdOp().result + return tpu.device_id() lowering_rules[tpu_primitives.device_id_p] = _device_id_lowering_rule def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): @@ -2930,7 +2916,7 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): # We are querying a named axis corresponding to a grid dimension. return _program_id_lowering_rule(ctx, axis=grid_names.index(axis_name)) # We are querying a named axis corresponding to a mesh dimension. - device_id = tpu.DeviceIdOp().result + device_id = tpu.device_id() mesh_context = ctx.lowering_context.mesh_context if mesh_context is None: raise ValueError("Mesh context is not set.") @@ -2946,12 +2932,13 @@ def _axis_index_rule(ctx: LoweringRuleContext, *, axis_name: Hashable): def _get_barrier_semaphore_rule(ctx: LoweringRuleContext): memref_type = aval_to_ir_type(ctx.avals_out[0]) - return tpu.GetBarrierSemaphoreOp(memref_type).result + return tpu.sem_barrier(memref_type) lowering_rules[tpu_primitives.get_barrier_semaphore_p] = _get_barrier_semaphore_rule def _delay_rule(ctx: LoweringRuleContext, nanos: int): - return tpu.DelayOp(nanos).results + tpu.delay(nanos) + return [] lowering_rules[tpu_primitives.delay_p] = _delay_rule @@ -2994,14 +2981,16 @@ def _prng_seed_lowering_rule(ctx: LoweringRuleContext, *seeds): # In the KeyScalarBundle case we unpack the bundle and set the seed with # the list of scalars. if len(seeds) == 1 and isinstance(seeds[0], KeyScalarBundle): - return tpu.PRNGSeed32Op(seeds[0].scalars).results + tpu.prng_set_seed_32(seeds[0].scalars) + return [] # For integer seeds, we can set the seed directly as PRNGSeed32Op natively # takes in a list of integers as input. all_integers = all(isinstance(seed.type, ir.IntegerType) for seed in seeds) if not all_integers: seed_types = [seed.type for seed in seeds] raise ValueError(f"All seed data must be scalar integers. Got {seed_types}") - return tpu.PRNGSeed32Op(seeds).results + tpu.prng_set_seed_32(seeds) + return [] lowering_rules[tpu_primitives.prng_seed_p] = _prng_seed_lowering_rule @@ -3011,13 +3000,12 @@ def _prng_random_bits_lowering_rule(ctx: LoweringRuleContext, *, shape): raise NotImplementedError("random_bits only supports rank>=2 outputs.") out_aval = ctx.avals_out[0] out_type = aval_to_ir_type(out_aval) - return tpu.PRNGRandomBitsOp(out_type).result + return tpu.prng_random_bits(out_type) lowering_rules[tpu_primitives.prng_random_bits_p] = _prng_random_bits_lowering_rule def random_seed_lowering(ctx, seeds, *, impl): - seed_lowering = lower_fun( - impl.seed, multiple_results=False) + seed_lowering = lower_fun(impl.seed, multiple_results=False) return seed_lowering(ctx, seeds) lowering_rules[prng.random_seed_p] = random_seed_lowering @@ -3040,8 +3028,7 @@ def new_lowering(key, bit_width, shape): def random_fold_in_lowering(ctx, keys, msgs): keys_aval, _ = ctx.avals_in impl = keys_aval.dtype._impl - fold_in_lowering = lower_fun( - impl.fold_in, multiple_results=False) + fold_in_lowering = lower_fun(impl.fold_in, multiple_results=False) return fold_in_lowering(ctx, keys, msgs) lowering_rules[prng.random_fold_in_p] = random_fold_in_lowering @@ -3061,7 +3048,7 @@ def random_unwrap_lowering(ctx, key): out_type = ir.VectorType.get( key.key_shape, _dtype_to_ir_type(jnp.dtype('int32')) ) - val = vector.BroadcastOp(out_type, scalar).result + val = vector.broadcast(out_type, scalar) return val lowering_rules[prng.random_unwrap_p] = random_unwrap_lowering @@ -3120,7 +3107,7 @@ def _checkify_lowering_rule( # so we need to compute not(pred) here. out_scalar_type = _dtype_to_ir_type(jnp.dtype('bool')) minus_one = ir_constant(-1, out_scalar_type) - not_pred = arith.XOrIOp(pred, minus_one).result + not_pred = arith.xori(pred, minus_one) attrs = {"msg": ir.StringAttr.get(exception.fmt_string)} ir.Operation.create("cf.assert", operands=(not_pred,),