
Commit

[Mosaic] Column shift relayouts for non-native tilings and packed types, except for (1, n) and packed

PiperOrigin-RevId: 661091012
tlongeri authored and jax authors committed Aug 9, 2024
1 parent f2068bb commit e57a7e3
Showing 1 changed file with 283 additions and 70 deletions.
353 changes: 283 additions & 70 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -5008,13 +5008,290 @@ FailureOr<xla::Array<Value>> tpu_rotate_with_overflow(
return out_tiles;
}

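// Rotate every vreg in `vregs` by `amount` along `dimension` (0 = sublanes,
// 1 = lanes) using tpu.RotateOp. Roughly, the value at index i along the
// dimension moves to index (i + amount) mod size, so e.g. rotating the lanes
// of an (8, 128) vreg by 10 moves lane c to lane (c + 10) % 128.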
void rotateVregs(OpBuilder &builder, xla::Array<Value> &vregs,
const int64_t amount, const int dimension) {
if (amount != 0) {
vregs.Each([&](absl::Span<const int64_t> idx, Value *vreg) {
CHECK(vreg);
*vreg = builder
.create<tpu::RotateOp>(vreg->getLoc(), *vreg,
/*amount=*/amount,
/*dimension=*/dimension,
/*stride=*/nullptr,
/*stride_dimension=*/nullptr)
.getResult();
});
}
}

void rotateSublanes(OpBuilder &builder, xla::Array<Value> &vregs,
const int64_t amount) {
rotateVregs(builder, vregs, amount, 0);
}

void rotateLanes(OpBuilder &builder, xla::Array<Value> &vregs,
const int64_t amount) {
rotateVregs(builder, vregs, amount, 1);
}

// Relayout src_vregs from layout src to layout dst, where dst is the same as
// src except that the column offset is dst_col_offset.
FailureOr<xla::Array<Value>> doColumnShiftRelayout(
OpBuilder &builder, const ArrayRef<int64_t> shape,
xla::Array<Value> src_vregs, const VectorLayout &src,
const int64_t dst_col_offset, const std::array<int64_t, 2> target_shape) {
CHECK(src.offsets()[1]);
const std::array<int64_t, 2> tiled_ishape =
src.getImplicitTiledDims(shape, 1);
const Location loc = src_vregs.begin()->getLoc();
const std::array<int64_t, 2> tiling = src.tiling();
const std::array<int64_t, 2> vreg_slice = src.vregSlice(target_shape);
const int bitwidth = src.bitwidth();
const int packing = src.packing();
const VectorLayout dst(bitwidth, {src.offsets()[0], dst_col_offset}, tiling,
src.implicit_dim());
const int64_t col_diff = dst_col_offset - *src.offsets()[1];
if (tiling[0] % packing != 0 || tiling[1] != target_shape[1]) {
return emitError(loc,
"Not implemented: Unsupported tiling for column shift");
}
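// For example, this rejects a (1, 128) tiling for a packed 16-bit type
// (packing = 2), since tiling[0] = 1 is not a multiple of the packing; this
// appears to be part of the "(1, n) and packed" exclusion in the commit
// title.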
// When shifting columns with multiple tiles per vreg, the overflowing
// columns of a tile move to the next tile, and they have to be shifted
down. For example, for a 32-bit layout with (2, 128) tiling, when shifting
// a vreg right by 138 (128 + 10):
//
// +---------------+---------+ +---------+---------------+
// | 0:118 | 118:128 | |-138:-128| -128:-10 |
// +---------------+---------+ +---------+---------------+
// | 128:246 | 246:256 | | -10:0 | 0:118 |
// +---------------+---------+ -> +---------+---------------+
// | 256:382 | 382:392 | | 118:128 | 128:246 |
// +---------------+---------+ +---------+---------------+
// | 392:502 | 502:512 | | 246:256 | 256:382 |
// +---------------+---------+ +---------+---------------+
//
// The negative numbers above are used for column intervals coming from the
// previous vreg (if there is one).
//
// We can break the result vreg down into four parts:
//
// +---------+---------------+
// | UL | UR |
// + +---------------+
// | | LR |
// +---------+ +
// | LL | |
// + + +
// | | |
// +---------+---------------+
//
// Our example shifts right, which causes the upper parts to come from the
// previous (along the minor dim) vreg of the array (if it exists) and the
// lower parts to come from the original "current" vreg.
//
// - LR (Lower Right) comes from the current vreg lane-rotated by 10, and
// sublane-rotated down by 2 (1 tile).
// - LL (Lower Left) comes from the current vreg lane-rotated by 10, and
// sublane-rotated down by 4 (2 tiles).
// - UR (Upper Right) comes from the previous vreg lane-shifted by 10, and
// sublane-rotated down by 2 (1 tile).
// - UL (Upper Left) comes from the previous vreg lane-shifted by 10, and
// sublane-rotated down by 4 (2 tiles).
//
// This partitioning also works similarly for left shifts, except that the
// upper parts come from the current vreg, and the lower parts come from the
// next vreg.
//
// In general, for any tiling and shift amount, we partition the result
// vreg into four parts as we did here. However, for some tilings and shift
// amounts, some of the parts may be empty. Some notable cases:
//
// - Tile-aligned shifts result in empty left parts.
// - Native tiling (a single tile per vreg) results in empty upper right and
// lower left parts.
// - Shifts right by less than 1 tile result in empty upper right parts, and
// shifts left by less than 1 tile result in empty lower left parts.

const int64_t sublanes_per_tile = src.sublanesPerTile(target_shape);
const int64_t tiles_per_vreg = src.tilesPerVreg(target_shape);

int64_t split_offset = col_diff;
int64_t upper_idx_delta = -1;
int64_t lower_idx_delta = 0;
if (col_diff < 0) {
split_offset += vreg_slice[1];
++upper_idx_delta;
++lower_idx_delta;
}
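// For example, for a right shift (col_diff > 0), dst vreg j takes its lower
// parts from src vreg j (lower_idx_delta = 0) and its upper parts from src
// vreg j - 1 (upper_idx_delta = -1), as in the partition diagram above.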
const int64_t left_tile_split = llvm::divideCeil(split_offset, tiling[1]);
const int64_t right_tile_split = split_offset / tiling[1];
const int64_t left_right_split = split_offset % tiling[1];
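// Continuing the (2, 128)-tiling example above (vreg_slice[1] = 512) with
// col_diff = 138: split_offset = 138, left_tile_split = ceil(138 / 128) = 2,
// right_tile_split = 138 / 128 = 1, and left_right_split = 138 % 128 = 10.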

rotateLanes(builder, src_vregs, left_right_split);
// TODO(tlongeri): Clean up. Some of these rotations may end up unused:
// - The left part of the first vreg and the right part of the last vreg
// may be entirely padding.
// - The entire left part may be unused if the shift is tile-aligned.
// They will be removed as dead code anyway, but it would be nicer to not
// generate them in the first place.
// Also, sometimes the rotation amount is 0, so we don't need to allocate
// another array (and we should steal the allocation of src_vregs, too).
xla::Array<Value> left_part = src_vregs;
xla::Array<Value> right_part = src_vregs;
rotateSublanes(builder, left_part,
left_tile_split * sublanes_per_tile % target_shape[0]);
rotateSublanes(builder, right_part,
right_tile_split * sublanes_per_tile % target_shape[0]);
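// In the running example, sublanes_per_tile = 2, so the left part is
// sublane-rotated by 2 * 2 % 8 = 4 (two tiles) and the right part by
// 1 * 2 % 8 = 2 (one tile), matching the "down by 4"/"down by 2" rotations
// described above.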
// We assemble the left and right halves first, and then combine them.
// TODO(tlongeri): Assembling lower and upper halves first is probably
// better, since they can be reused for consecutive vregs: we can assemble
// lower_left+lower_right for one vreg and upper_left+upper_right for the
// next one in the same vselect. But the mask for assembling upper+lower is
// not as simple, so it might be a bit more expensive to generate. It seems
// worth it for large vreg arrays; I'm not sure about small ones (especially
// on older TPU generations).
const auto mask_vreg_ty = VectorType::get(
packing == 1
? target_shape
: ArrayRef<int64_t>{target_shape[0], target_shape[1], packing},
builder.getI1Type());
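// For packed types the mask is given a trailing `packing` dimension, e.g.
// vector<8x128x2xi1> for a 16-bit type (packing = 2) on an (8, 128) target,
// presumably so that each packed subelement is selected consistently.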
Value left_mask = nullptr;
Value right_mask = nullptr;
Value left_right_mask = nullptr;
auto get_left_mask = [&]() {
if (left_mask == nullptr) {
left_mask = builder.create<tpu::CreateMaskOp>(
loc, mask_vreg_ty,
ArrayRef<Value>{IdxConst(0, builder, loc), IdxConst(0, builder, loc)},
ArrayRef<Value>{
IdxConst(left_tile_split * sublanes_per_tile, builder, loc),
IdxConst(target_shape[1], builder, loc)});
}
return left_mask;
};
auto get_right_mask = [&]() {
if (right_mask == nullptr) {
right_mask = builder.create<tpu::CreateMaskOp>(
loc, mask_vreg_ty,
ArrayRef<Value>{IdxConst(0, builder, loc), IdxConst(0, builder, loc)},
ArrayRef<Value>{
IdxConst(right_tile_split * sublanes_per_tile, builder, loc),
IdxConst(target_shape[1], builder, loc)});
}
return right_mask;
};
auto get_left_right_mask = [&]() {
if (left_right_mask == nullptr) {
left_right_mask = builder.create<tpu::CreateMaskOp>(
loc, mask_vreg_ty,
ArrayRef<Value>{IdxConst(0, builder, loc), IdxConst(0, builder, loc)},
ArrayRef<Value>{IdxConst(target_shape[0], builder, loc),
IdxConst(left_right_split, builder, loc)});
}
return left_right_mask;
};
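// In the running example: the left mask covers sublanes [0, 4), the right
// mask sublanes [0, 2), and the left-right mask lanes [0, 10). Together with
// the selects below, UL ends up covering lanes [0, 10) x sublanes [0, 4),
// LL lanes [0, 10) x sublanes [4, 8), UR lanes [10, 128) x sublanes [0, 2),
// and LR lanes [10, 128) x sublanes [2, 8).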
xla::Array<Value> dst_vregs(VectorLayout(bitwidth,
{src.offsets()[0], dst_col_offset},
tiling, src.implicit_dim())
.tileArrayImplicitShape(shape, target_shape));
dst_vregs.Each([&](absl::Span<const int64_t> dst_idx, Value *dst_vreg) {
SmallVector<int64_t> dst_idx_local(toArrayRef(dst_idx));
Value lower_left = nullptr;
Value lower_right = nullptr;
Value upper_left = nullptr;
Value upper_right = nullptr;
// Set each part if it is non-empty and its source vreg exists.
*(dst_idx_local.end() - 1) += lower_idx_delta;
if (*(dst_idx_local.end() - 1) < *(src_vregs.dimensions().end() - 1)) {
if (left_tile_split < tiles_per_vreg && 0 < left_right_split) {
lower_left = left_part(dst_idx_local);
}
if (right_tile_split < tiles_per_vreg) {
lower_right = right_part(dst_idx_local);
}
}
*(dst_idx_local.end() - 1) -= lower_idx_delta;
*(dst_idx_local.end() - 1) += upper_idx_delta;
if (*(dst_idx_local.end() - 1) >= 0) {
if (0 < left_tile_split && 0 < left_right_split) {
upper_left = left_part(dst_idx_local);
}
if (0 < right_tile_split) {
upper_right = right_part(dst_idx_local);
}
}
*(dst_idx_local.end() - 1) -= upper_idx_delta;

// For the first and last vregs, some parts may be all padding, so
// unset them if this is the case. Note that the first and last vreg
// are the same when there is only one.
if (*(dst_idx_local.end() - 1) == 0) {
// We check the final offset (note that this is different from the rotate
// amount) against the thresholds of the last columns of vreg parts.
if (right_tile_split * tiling[1] <= dst_col_offset) {
// Note: When shifting right, UR is always all-padding.
upper_right = nullptr;
}
if (split_offset <= dst_col_offset) {
// Note: When shifting right, UL is always all-padding. When shifting
// left, UL is never all-padding (unless this is also the last vreg,
// possibly).
upper_left = nullptr;
}
if (vreg_slice[1] - tiling[1] + left_right_split <= dst_col_offset) {
// Note: When shifting right, LL is only all-padding if the source
// offset is in the last tile. When shifting left, LL is never
// all-padding (unless this is also the last vreg, possibly).
lower_left = nullptr;
}
}
if (*(dst_idx_local.end() - 1) == *(dst_vregs.dimensions().end() - 1) - 1) {
// We check the final end offset against the thresholds of the first
// columns of vreg parts.
const uint64_t end_offset =
(dst_col_offset + tiled_ishape[1] - 1) % vreg_slice[1] + 1;
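// For example, with dst_col_offset = 138, vreg_slice[1] = 512 and a
// hypothetical tiled_ishape[1] = 300, end_offset = (138 + 300 - 1) % 512 + 1
// = 438: the last vreg's data ends at column 438.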
if (end_offset <= left_tile_split * tiling[1]) {
// Note: When shifting left, LL is always all-padding.
lower_left = nullptr;
}
if (end_offset <= split_offset) {
// Note: When shifting left, LR is always all-padding. When shifting
// right, LR is never all-padding (unless this is also the first vreg,
// possibly).
lower_right = nullptr;
}
if (end_offset <= left_right_split) {
// Note: When shifting left, UR is only all-padding if the original
// end offset is in the first tile. When shifting right, UR is never
// all-padding (unless this is also the first vreg, possibly).
upper_right = nullptr;
}
}
// Combine parts into the final vreg (see comment in mask definitions).
auto combine_parts = [&builder](Value part1, Value part2,
auto get_mask_fn) -> Value {
if (part1 && part2) {
return builder.create<arith::SelectOp>(part1.getLoc(), get_mask_fn(),
part1, part2);
} else if (part1) {
return part1;
} else {
return part2;
}
};
Value left = combine_parts(upper_left, lower_left, get_left_mask);
Value right = combine_parts(upper_right, lower_right, get_right_mask);
*dst_vreg = combine_parts(left, right, get_left_right_mask);
CHECK(*dst_vreg);
});
return dst_vregs;
}

FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeOffsets(
OpBuilder &builder, const std::array<int64_t, 2> target_shape,
const Location loc, const VectorType vty, const VectorLayout src,
xla::Array<Value> vregs, const LayoutOffsets dst_offsets) {
const VectorLayout dst(src.bitwidth(), dst_offsets, src.tiling(),
src.implicit_dim());
const auto &tiling = src.tiling();
const int packing = src.packing();
const int8_t bitwidth = src.bitwidth();

@@ -5061,15 +5338,7 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeOffsets(
if (sublane_diff < 0) {
sublane_diff += target_shape[0];
}
vregs.Each([&](absl::Span<const int64_t> idx, Value *tile) {
*tile =
builder
.create<tpu::RotateOp>(loc, *tile,
/*amount=*/sublane_diff,
/*dimension=*/0, /*stride=*/nullptr,
/*stride_dimension=*/nullptr)
.getResult();
});
rotateSublanes(builder, vregs, sublane_diff);
}
const int src_subelem = *src.offsets()[0] % packing;
const int dst_subelem = *dst.offsets()[0] % packing;
@@ -5108,68 +5377,12 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeOffsets(
SmallVector<int64_t> dst_tiles_shape =
dst.tileArrayImplicitShape(vty.getShape(), target_shape);
CHECK_EQ(*(dst_tiles_shape.end() - 2), *(vregs.dimensions().end() - 2));
if (dst_tiles_shape.back() != vregs.dimensions().back()) {
return emitError(loc,
"Not implemented: Offsets changing the vreg array shape");
}

// TODO(tlongeri): Clean up col_diff and pass the dst offset directly.
if (col_diff != 0) {
if (bitwidth != 32 || tiling != target_shape) {
return emitError(loc,
"Not implemented: Only 32-bit column shifts for "
"native layouts supported");
}
TPU_ASSERT_GE_LOC(loc, vregs.num_dimensions(), 1);
std::optional<tpu::CreateMaskOp> maybe_create_mask;
if (*(vregs.dimensions().end() - 1) > 1) {
int64_t lane_start, lane_end;
if (col_diff > 0) {
lane_start = 0;
lane_end = col_diff;
} else { // col_diff < 0
lane_start = target_shape[1] + col_diff;
lane_end = target_shape[1];
}
auto boundIdxConst =
std::bind(IdxConst, std::placeholders::_1, builder, loc);
maybe_create_mask = builder.create<tpu::CreateMaskOp>(
loc, VectorType::get(target_shape, builder.getI1Type()),
ValueRange{boundIdxConst(0), boundIdxConst(lane_start)},
ValueRange{boundIdxConst(target_shape[0]), boundIdxConst(lane_end)});
}
auto rotated_vregs = vregs;
rotated_vregs.Each([&](absl::Span<const int64_t> idx, Value *tile) {
*tile = builder
.create<tpu::RotateOp>(loc, *tile,
/*amount=*/col_diff < 0
? target_shape[1] + col_diff
: col_diff,
/*dimension=*/1, /*stride=*/nullptr,
/*stride_dimension=*/nullptr)
.getResult();
});
vregs.Each([&](absl::Span<const int64_t> idx, Value *result) {
Value rot_tile = rotated_vregs(idx);
Value prev_rot_tile;
if (col_diff > 0) {
if (*(idx.end() - 1) != 0) {
SmallVector<int64_t> prev_idx(idx.begin(), idx.end());
--*(prev_idx.end() - 1);
prev_rot_tile = rotated_vregs(prev_idx);
}
} else { // col_diff < 0
if (*(idx.end() - 1) != *(rotated_vregs.dimensions().end() - 1) - 1) {
SmallVector<int64_t> prev_idx(idx.begin(), idx.end());
++*(prev_idx.end() - 1);
prev_rot_tile = rotated_vregs(prev_idx);
}
}
if (prev_rot_tile != nullptr) {
rot_tile = builder.create<arith::SelectOp>(
loc, maybe_create_mask->getResult(), prev_rot_tile, rot_tile);
}
*result = rot_tile;
});
FAILUREOR_ASSIGN_OR_RETURN(
vregs, doColumnShiftRelayout(builder, vty.getShape(), std::move(vregs),
src, *dst.offsets()[1], target_shape));
}
return std::make_pair(dst, std::move(vregs));
}
