Skip to content

Commit

Permalink
[Mosaic] Support left shifting relayouts
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 616335006
  • Loading branch information
tlongeri authored and jax authors committed Mar 16, 2024
1 parent ab2e906 commit 465b40a
Showing 1 changed file with 37 additions and 19 deletions.
56 changes: 37 additions & 19 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4349,38 +4349,56 @@ FailureOr<TypedValue<VectorType>> relayout(
return emitError(v.getLoc(),
"Not implemented: Both columns and rows are shifted");
}
if (col_diff < 0) {
return emitError(v.getLoc(), "Not implemented: Shifts to the left");
}
if (bitwidth != 32 || tiling != target_shape) {
return emitError(v.getLoc(),
"Not implemented: Only 32-bit column shifts for "
"native layouts supported");
}
const int64_t sublane_diff = col_diff;
TPU_ASSERT_GE_LOC(v.getLoc(), src_tiles.num_dimensions(), 1);
std::optional<tpu::CreateMaskOp> maybe_create_mask;
if (src_tiles.dimensions()[src_tiles.num_dimensions() - 1] > 1) {
if (*(src_tiles.dimensions().end() - 1) > 1) {
int64_t lane_start, lane_end;
if (col_diff > 0) {
lane_start = 0;
lane_end = col_diff;
} else { // col_diff < 0
lane_start = target_shape[1] + col_diff;
lane_end = target_shape[1];
}
auto boundIdxConst =
std::bind(IdxConst, std::placeholders::_1, builder, v.getLoc());
maybe_create_mask = builder.create<tpu::CreateMaskOp>(
v.getLoc(), VectorType::get(target_shape, builder.getI1Type()),
ValueRange{boundIdxConst(0), boundIdxConst(0)},
ValueRange{boundIdxConst(0), boundIdxConst(lane_start)},
ValueRange{boundIdxConst(target_shape[0]),
boundIdxConst(col_diff)});
boundIdxConst(lane_end)});
}
src_tiles.Each([&](absl::Span<const int64_t> idx, Value tile) {
Value rot_tile =
builder
.create<tpu::RotateOp>(v.getLoc(), tile,
/*amount=*/sublane_diff,
/*dimension=*/1, /*stride=*/nullptr,
/*stride_dimension=*/nullptr)
.getResult();
if (idx[idx.size() - 1] != 0) {
SmallVector<int64_t> prev_idx(idx.begin(), idx.end());
--prev_idx[idx.size() - 1];
Value prev_rot_tile = dst_tiles(prev_idx);
dst_tiles.Each([&](absl::Span<const int64_t> idx, Value *tile) {
*tile = builder
.create<tpu::RotateOp>(v.getLoc(), *tile,
/*amount=*/col_diff < 0
? target_shape[1] + col_diff
: col_diff,
/*dimension=*/1, /*stride=*/nullptr,
/*stride_dimension=*/nullptr)
.getResult();
});
dst_tiles.Each([&](absl::Span<const int64_t> idx, Value rot_tile) {
Value prev_rot_tile;
if (col_diff > 0) {
if (*(idx.end() - 1) != 0) {
SmallVector<int64_t> prev_idx(idx.begin(), idx.end());
--*(prev_idx.end() - 1);
prev_rot_tile = dst_tiles(prev_idx);
}
} else { // col_diff < 0
if (*(idx.end() - 1) != *(src_tiles.dimensions().end() - 1) - 1) {
SmallVector<int64_t> prev_idx(idx.begin(), idx.end());
++*(prev_idx.end() - 1);
prev_rot_tile = dst_tiles(prev_idx);
}
}
if (prev_rot_tile != nullptr) {
rot_tile = builder.create<arith::SelectOp>(
v.getLoc(), maybe_create_mask->getResult(), prev_rot_tile,
rot_tile);
Expand Down

0 comments on commit 465b40a

Please sign in to comment.