diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index e02fdb365085..f928c3766f9b 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -4349,38 +4349,56 @@ FailureOr> relayout( return emitError(v.getLoc(), "Not implemented: Both columns and rows are shifted"); } - if (col_diff < 0) { - return emitError(v.getLoc(), "Not implemented: Shifts to the left"); - } if (bitwidth != 32 || tiling != target_shape) { return emitError(v.getLoc(), "Not implemented: Only 32-bit column shifts for " "native layouts supported"); } - const int64_t sublane_diff = col_diff; TPU_ASSERT_GE_LOC(v.getLoc(), src_tiles.num_dimensions(), 1); std::optional maybe_create_mask; - if (src_tiles.dimensions()[src_tiles.num_dimensions() - 1] > 1) { + if (*(src_tiles.dimensions().end() - 1) > 1) { + int64_t lane_start, lane_end; + if (col_diff > 0) { + lane_start = 0; + lane_end = col_diff; + } else { // col_diff < 0 + lane_start = target_shape[1] + col_diff; + lane_end = target_shape[1]; + } auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder, v.getLoc()); maybe_create_mask = builder.create( v.getLoc(), VectorType::get(target_shape, builder.getI1Type()), - ValueRange{boundIdxConst(0), boundIdxConst(0)}, + ValueRange{boundIdxConst(0), boundIdxConst(lane_start)}, ValueRange{boundIdxConst(target_shape[0]), - boundIdxConst(col_diff)}); + boundIdxConst(lane_end)}); } - src_tiles.Each([&](absl::Span idx, Value tile) { - Value rot_tile = - builder - .create(v.getLoc(), tile, - /*amount=*/sublane_diff, - /*dimension=*/1, /*stride=*/nullptr, - /*stride_dimension=*/nullptr) - .getResult(); - if (idx[idx.size() - 1] != 0) { - SmallVector prev_idx(idx.begin(), idx.end()); - --prev_idx[idx.size() - 1]; - Value prev_rot_tile = dst_tiles(prev_idx); + dst_tiles.Each([&](absl::Span idx, Value *tile) { + *tile = builder + .create(v.getLoc(), *tile, + /*amount=*/col_diff < 0 + ? target_shape[1] + col_diff + : col_diff, + /*dimension=*/1, /*stride=*/nullptr, + /*stride_dimension=*/nullptr) + .getResult(); + }); + dst_tiles.Each([&](absl::Span idx, Value rot_tile) { + Value prev_rot_tile; + if (col_diff > 0) { + if (*(idx.end() - 1) != 0) { + SmallVector prev_idx(idx.begin(), idx.end()); + --*(prev_idx.end() - 1); + prev_rot_tile = dst_tiles(prev_idx); + } + } else { // col_diff < 0 + if (*(idx.end() - 1) != *(src_tiles.dimensions().end() - 1) - 1) { + SmallVector prev_idx(idx.begin(), idx.end()); + ++*(prev_idx.end() - 1); + prev_rot_tile = dst_tiles(prev_idx); + } + } + if (prev_rot_tile != nullptr) { rot_tile = builder.create( v.getLoc(), maybe_create_mask->getResult(), prev_rot_tile, rot_tile);