[Mosaic][NFC] Use TypedValue<VectorType> instead of Value for applicable arguments/return values in `disassemble` and `relayout`

Ideally we would prefer `TypedValue<VectorType>` everywhere possible for static type checking. However, when I tried changing the arrays of vregs from `xla::Array<Value>` to `xla::Array<TypedValue<VectorType>>`, I ran into issues because MLIR support for arrays/ranges of `TypedValue`s seems lacking.
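
As a rough illustration of what the stronger type buys (a minimal sketch with made-up helper names, not code from this change):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// With a plain Value, the callee has to cast first and can only fail at runtime.
int64_t rankOfUntypedVector(mlir::Value v) {
  return mlir::cast<mlir::VectorType>(v.getType()).getRank();
}

// With TypedValue<VectorType>, "this is a vector" is enforced at compile time
// and getType() already returns a VectorType.
int64_t rankOfTypedVector(mlir::TypedValue<mlir::VectorType> v) {
  return v.getType().getRank();
}
```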

For example, I can't find a good way to get a `ValueRange` (which many op constructors take) from an array of `TypedValue`s without creating an intermediate vector of `Value`s. An unsafe cast might work, if we were willing to make the (probably not guaranteed) assumption that `sizeof(TypedValue)` equals `sizeof(Value)`.
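
Concretely, the workaround looks something like the sketch below (hypothetical helper, approximate include paths, not code from this change):

```cpp
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "xla/array.h"

// Flatten an array of typed vregs into a vector that a ValueRange can wrap.
// The SmallVector<Value> is the intermediate copy mentioned above; each
// TypedValue<VectorType> converts to its Value base on the way in.
llvm::SmallVector<mlir::Value> flattenVregs(
    const xla::Array<mlir::TypedValue<mlir::VectorType>> &vregs) {
  return llvm::SmallVector<mlir::Value>(vregs.begin(), vregs.end());
}
```

An op builder call that takes a `ValueRange` could then consume `flattenVregs(...)` directly, at the cost of one extra copy per call site.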

Also note that MLIR itself uses untyped `Value`s for ranges of op results and operands even when the op definition declares them to be of a specific type.
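
For instance, the variadic `$input` of `tpu.roll_vectors` is exposed as a range of plain `Value`s even after the `tpu.td` change in this commit (accessor name and include path below are assumptions, not quoted from the generated code):

```cpp
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Value.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"  // assumed header for tpu::RollVectorsOp

// $input is declared Variadic<AnyVector>, yet the generated getter returns a
// range of plain mlir::Value, so callers still see untyped values.
llvm::SmallVector<mlir::Value> collectRolledInputs(mlir::tpu::RollVectorsOp op) {
  return llvm::SmallVector<mlir::Value>(op.getInput().begin(),
                                        op.getInput().end());
}
```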

PiperOrigin-RevId: 610509743
tlongeri authored and jax authors committed Feb 26, 2024
1 parent ca1844d commit 57e34e1
Showing 4 changed files with 71 additions and 59 deletions.
11 changes: 8 additions & 3 deletions jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc
@@ -349,8 +349,10 @@ MlirTpuValueArray mlirTpuDisassemble(MlirTpuInsertionPoint insertion_point,
MlirTpuVectorLayout layout, MlirValue val,
MlirTpuI64TargetTuple target_shape) {
mlir::OpBuilder builder = mlirTpuInsertionPointToOpBuilder(insertion_point);
// This cast will fail and assert if the caller passed a non-vector
auto vector_val = mlir::cast<mlir::TypedValue<mlir::VectorType>>(unwrap(val));
mlir::FailureOr<xla::Array<mlir::Value>> failure_or_vals =
mlir::tpu::disassemble(builder, *unwrap(layout), unwrap(val),
mlir::tpu::disassemble(builder, *unwrap(layout), vector_val,
unwrap(target_shape));
if (failed(failure_or_vals)) {
return {{nullptr, 0}, nullptr};
@@ -371,8 +373,11 @@ MlirValue mlirTpuRelayout(MlirTpuInsertionPoint insertion_point, MlirValue val,
MlirTpuVectorLayout src, MlirTpuVectorLayout dst,
MlirTpuI64TargetTuple target_shape) {
mlir::OpBuilder builder = mlirTpuInsertionPointToOpBuilder(insertion_point);
mlir::FailureOr<mlir::Value> failure_or_new_val = mlir::tpu::relayout(
builder, unwrap(val), *unwrap(src), *unwrap(dst), unwrap(target_shape));
// This cast will fail and assert if the caller passed a non-vector
auto vector_val = mlir::cast<mlir::TypedValue<mlir::VectorType>>(unwrap(val));
mlir::FailureOr<mlir::TypedValue<mlir::VectorType>> failure_or_new_val =
mlir::tpu::relayout(builder, vector_val, *unwrap(src), *unwrap(dst),
unwrap(target_shape));
if (failed(failure_or_new_val)) {
return {nullptr};
}
8 changes: 4 additions & 4 deletions jaxlib/mosaic/dialect/tpu/tpu.td
@@ -325,16 +325,16 @@ def TPU_BitcastVregOp : TPU_Op<"bitcast_vreg", [Pure]> {
}

def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> {
let arguments = (ins Variadic<AnyType>:$input);
let results = (outs AnyType:$output);
let arguments = (ins Variadic<AnyVector>:$input);
let results = (outs AnyVector:$output);
let assemblyFormat = [{
$input attr-dict `:` type($input) `->` type($output)
}];
}

def TPU_UnrollVectorsOp : TPU_Op<"unroll_vectors", [Pure]> {
let arguments = (ins AnyType:$input);
let results = (outs Variadic<AnyType>:$output);
let arguments = (ins AnyVector:$input);
let results = (outs Variadic<AnyVector>:$output);
let hasCanonicalizeMethod = 1;
let assemblyFormat = [{
$input attr-dict `:` type($input) `->` type($output)
100 changes: 52 additions & 48 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -590,9 +590,11 @@ LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op,
SmallVector<xla::Array<Value>> in_vreg_arrays;
in_vreg_arrays.reserve(num_operands);
for (unsigned i = 0; i < num_operands; ++i) {
FAILUREOR_ASSIGN_OR_RETURN(xla::Array<Value> tile_array,
disassemble(builder, *layouts_in[i],
op.getOperand(i), ctx.target_shape));
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> tile_array,
disassemble(builder, *layouts_in[i],
cast<TypedValue<VectorType>>(op.getOperand(i)),
ctx.target_shape));
in_vreg_arrays.emplace_back(std::move(tile_array));
}

@@ -653,15 +655,16 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
const VectorLayout &layout_in,
const VectorLayout &layout_out) {
ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
auto result_ty = cast<VectorType>(op.getResult().getType());
auto source_ty = cast<VectorType>(op.getIn().getType());
const auto result_ty = cast<VectorType>(op.getResult().getType());
auto source = cast<TypedValue<VectorType>>(op.getIn());
const auto source_ty = source.getType();
if (layout_out.bitwidth() != 32) {
return op.emitOpError(
"Not implemented: Only extensions to 32-bit supported");
}
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> input_vregs,
disassemble(builder, layout_in, op.getIn(), ctx.target_shape));
disassemble(builder, layout_in, source, ctx.target_shape));
xla::Array<Value> output_vregs(
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
FAILUREOR_ASSIGN_OR_RETURN(
@@ -762,7 +765,8 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
auto result_ty = cast<VectorType>(op.getResult().getType());
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> input_vregs,
disassemble(builder, layout_in, op.getIn(), ctx.target_shape));
disassemble(builder, layout_in, cast<TypedValue<VectorType>>(op.getIn()),
ctx.target_shape));
xla::Array<Value> output_vregs(
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
if (layout_in.bitwidth() != 32) {
@@ -905,13 +909,13 @@ LogicalResult scf_for_rule(RewriteContext &ctx, Operation &op,
}
continue;
}
if (auto vty = dyn_cast<VectorType>(operand.getType())) {
if (auto vector_operand = dyn_cast<TypedValue<VectorType>>(operand)) {
if (!layout.has_value()) {
return op.emitOpError("Expected layout for vector operand");
}
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> tiles,
disassemble(builder, *layout, operand, ctx.target_shape));
disassemble(builder, *layout, vector_operand, ctx.target_shape));
unrolled_args.append(tiles.begin(), tiles.end());
} else {
if (layout.has_value()) {
@@ -1098,12 +1102,12 @@ LogicalResult scf_yield_rule(RewriteContext &ctx, Operation &op,
SmallVector<Value> unrolled;
for (auto [operand, layout] :
llvm::zip_equal(yield_op.getOperands(), layouts_in)) {
if (auto vty = dyn_cast<VectorType>(operand.getType())) {
if (auto vector_operand = dyn_cast<TypedValue<VectorType>>(operand)) {
// When the operand has vector type, disassemble the operand.
TPU_ASSERT_OP(layout.has_value());
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> tiles,
disassemble(builder, *layout, operand, ctx.target_shape));
disassemble(builder, *layout, vector_operand, ctx.target_shape));
unrolled.append(tiles.begin(), tiles.end());
} else {
TPU_ASSERT_OP(!layout.has_value());
@@ -1745,7 +1749,8 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
for (Value operand : concatenate_op.getOperands()) {
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> t,
disassemble(builder, layout, operand, ctx.target_shape));
disassemble(builder, layout, cast<TypedValue<VectorType>>(operand),
ctx.target_shape));
tiles.emplace_back(std::move(t));
}
const xla::Array<Value> res_tiles = concatenate(tiles, dimension);
@@ -2227,7 +2232,8 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,
const VectorType dst_ty = broadcast_op.getResult().getType();
const SmallVector<int64_t> dst_tiles_shape =
layout_out.tileArrayShape(dst_ty.getShape(), ctx.target_shape);
if (auto src_ty = dyn_cast<VectorType>(broadcast_op.getSourceType())) {
if (auto src = dyn_cast<TypedValue<VectorType>>(broadcast_op.getSource())) {
VectorType src_ty = src.getType();
TPU_ASSERT_OP(maybe_layout_in.has_value());
const VectorLayout &layout_in = *maybe_layout_in;
if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
@@ -2301,8 +2307,7 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op,

FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> src_tiles,
disassemble(builder, layout_in, broadcast_op.getSource(),
ctx.target_shape));
disassemble(builder, layout_in, src, ctx.target_shape));
xla::Array<Value> dst_tiles(dst_tiles_shape);
if (no_op) {
SmallVector<int64_t> reshape_dims(expand_rank, 1);
@@ -2666,10 +2671,9 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
return multi_reduction_op.emitOpError(
"Not implemented: Can only reduce into vectors");
}
if (!layouts_out.front().has_value()) {
// Shouldn't be empty since result is a vector
return op.emitOpError("Expected non-null output layout");
}
// Op definition enforces that accumulator type must match result type
auto acc = cast<TypedValue<VectorType>>(multi_reduction_op.getAcc());
TPU_ASSERT_OP(layouts_out.front().has_value());

const ArrayAttr dim_attrs = multi_reduction_op.getReductionDims();
SmallVector<int64_t> dims;
@@ -2686,11 +2690,9 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
}
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> acc_vregs,
disassemble(builder, acc_layout, multi_reduction_op.getAcc(),
ctx.target_shape));
const Value acc_vreg = *acc_vregs.begin();
auto acc_def =
dyn_cast_if_present<arith::ConstantOp>(acc_vreg.getDefiningOp());
disassemble(builder, acc_layout, acc, ctx.target_shape));
auto acc_def = dyn_cast_if_present<arith::ConstantOp>(
acc_vregs.begin()->getDefiningOp());
if (acc_def == nullptr) {
return multi_reduction_op.emitOpError(
"Not implemented: Only constant accumulator supported");
@@ -2838,7 +2840,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
}
xla::Array<Value> reduced_vregs =
src_vregs.Slice(src_slice_start, src_slice_end);
std::optional<Value> acc;
std::optional<Value> acc_vreg;
auto reduction_status = reduced_vregs.EachStatus(
[&](const absl::Span<const int64_t> red_idx,
Value *const src_vreg) {
@@ -2860,17 +2862,17 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
return absl::UnknownError("");
}
Value vreg = failure_or_vreg.value();
if (!acc.has_value()) {
acc = vreg;
if (!acc_vreg.has_value()) {
acc_vreg = vreg;
} else {
switch (tpu_kind) {
case tpu::ReductionKind::SUM:
acc = builder.create<arith::AddFOp>(vreg.getLoc(), *acc,
vreg);
acc_vreg = builder.create<arith::AddFOp>(vreg.getLoc(),
*acc_vreg, vreg);
break;
case tpu::ReductionKind::MAX:
acc = builder.create<arith::MaximumFOp>(vreg.getLoc(), *acc,
vreg);
acc_vreg = builder.create<arith::MaximumFOp>(
vreg.getLoc(), *acc_vreg, vreg);
break;
}
}
@@ -2879,16 +2881,16 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
if (!reduction_status.ok()) {
return reduction_status;
}
TPU_ASSERT_OP(acc.has_value());
TPU_ASSERT_OP(acc_vreg.has_value());
if (reduces[1]) {
acc = builder.create<tpu::AllReduceOp>(multi_reduction_op->getLoc(),
*acc, 1, tpu_kind);
acc_vreg = builder.create<tpu::AllReduceOp>(
multi_reduction_op->getLoc(), *acc_vreg, 1, tpu_kind);
}
if (reduces[0]) {
acc = builder.create<tpu::AllReduceOp>(multi_reduction_op->getLoc(),
*acc, 0, tpu_kind);
acc_vreg = builder.create<tpu::AllReduceOp>(
multi_reduction_op->getLoc(), *acc_vreg, 0, tpu_kind);
}
*dst_vreg = *acc;
*dst_vreg = *acc_vreg;
return absl::OkStatus();
});
if (!all_results_ok.ok()) {
Expand Down Expand Up @@ -3478,9 +3480,10 @@ RollVectorsOp assemble(OpBuilder &builder, VectorType vty,
// Returns:
// An ndarray of MLIR values representing the tiling of val given by layout.
FailureOr<xla::Array<Value>> disassemble(
OpBuilder &builder, const VectorLayout &layout, const Value val,
OpBuilder &builder, const VectorLayout &layout,
const TypedValue<VectorType> val,
const std::array<int64_t, 2> target_shape) {
const auto vty = cast<VectorType>(val.getType());
const auto vty = val.getType();
const auto op_result = dyn_cast<OpResult>(val);
if (op_result == nullptr) {
return failure();
@@ -3869,15 +3872,15 @@ Value copy_one_sublane(OpBuilder &builder, Value src_vreg, int src_sl_idx,
}

// TODO(apaszke): Test this function properly
FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
const VectorLayout &dst,
const std::array<int64_t, 2> target_shape) {
FailureOr<TypedValue<VectorType>> relayout(
OpBuilder &builder, TypedValue<VectorType> v, VectorLayout src,
const VectorLayout &dst, const std::array<int64_t, 2> target_shape) {
const int8_t bitwidth = src.bitwidth();
if (bitwidth != dst.bitwidth()) {
return emitError(v.getLoc(), "Can't change bitwidth during a relayout");
}
const int packing = src.packing();
VectorType vty = cast<VectorType>(v.getType());
VectorType vty = v.getType();
FAILUREOR_ASSIGN_OR_RETURN(xla::Array<Value> src_tiles,
disassemble(builder, src, v, target_shape));
SmallVector<int64_t> dst_tiles_shape =
@@ -4202,16 +4205,17 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
for (auto [idx, tup] :
llvm::enumerate(llvm::zip(op.getOperands(), layouts_in))) {
auto [operand, li] = tup;
auto vty = dyn_cast<VectorType>(operand.getType());
TPU_ASSERT_EQ_OP(vty != nullptr, li.has_value());
if (vty == nullptr) {
auto vector_operand = dyn_cast<TypedValue<VectorType>>(operand);
TPU_ASSERT_EQ_OP(vector_operand != nullptr, li.has_value());
if (vector_operand == nullptr) {
continue;
}
auto vty = vector_operand.getType();

// The operand should always be an Operation (and not a BlockArgument)
// since we expect the FuncOp to have only memrefs and semaphores as
// arguments.
auto op_result = dyn_cast<OpResult>(operand);
auto op_result = dyn_cast<OpResult>(vector_operand);
if (op_result == nullptr) {
return op.emitError("Expected operand to be an operation result");
}
@@ -4227,7 +4231,7 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
}
OpBuilder builder(&op);
FAILUREOR_ASSIGN_OR_RETURN(Value new_v,
relayout(builder, operand, /*src=*/*lo,
relayout(builder, vector_operand, /*src=*/*lo,
/*dst=*/*li, ctx.target_shape));
op.setOperand(idx, new_v);
}
11 changes: 7 additions & 4 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h
@@ -29,7 +29,8 @@ RollVectorsOp assemble(OpBuilder &builder, VectorType vty,
const xla::Array<Value> &vals,
std::array<int64_t, 2> target_shape);
FailureOr<xla::Array<Value>> disassemble(OpBuilder &builder,
const VectorLayout &layout, Value val,
const VectorLayout &layout,
TypedValue<VectorType> val,
std::array<int64_t, 2> target_shape);

// Rewrites the operation according to its layout annotations.
@@ -55,9 +56,11 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op);
//
// Returns:
// A new MLIR vector value, laid out as requested by dst.
FailureOr<Value> relayout(OpBuilder &builder, Value v, VectorLayout src,
const VectorLayout &dst,
std::array<int64_t, 2> target_shape);
FailureOr<TypedValue<VectorType>> relayout(OpBuilder &builder,
TypedValue<VectorType> v,
VectorLayout src,
const VectorLayout &dst,
std::array<int64_t, 2> target_shape);

} // namespace mlir::tpu

