diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 68d24ed8f410..c0d1db245e17 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -197,12 +197,12 @@ def TPU_LoadOp : TPU_Op<"load"> { def TPU_StridedLoadOp : TPU_Op<"strided_load"> { let arguments = (ins AnyMemRef:$base, - DenseI32ArrayAttr:$indices, + Variadic:$indices, DenseI32ArrayAttr:$strides ); let results = (outs AnyVector:$result); let assemblyFormat = [{ - $base attr-dict `:` type($base) `,` type($result) + $base `[` $indices `]` attr-dict `:` type($base) `,` type($result) }]; let hasVerifier = 1; } @@ -211,12 +211,12 @@ def TPU_StridedStoreOp : TPU_Op<"strided_store"> { let arguments = (ins AnyVector:$valueToStore, AnyMemRef:$base, - DenseI32ArrayAttr:$indices, + Variadic:$indices, DenseI32ArrayAttr:$strides ); let results = (outs); let assemblyFormat = [{ - $base `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore) + $base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore) }]; let hasVerifier = 1; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 53941cbd2a9d..fd8785684346 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -202,21 +202,10 @@ LogicalResult verifyStridedOp(Op op, MemRefType memref_ty, return failure(); } for (int64_t i = 0; i < memref_ty.getRank(); ++i) { - if (indices[i] < 0 && indices[i] >= memref_ty.getDimSize(i)) { - op.emitError("Indices[") - << i << "]=" << indices[i] << " is out of range [0, " - << memref_ty.getDimSize(i) << ")"; - return failure(); - } if (strides[i] < 1) { op.emitError("Strides[") << i << "]=" << strides[i] << " must be >= 1"; return failure(); } - if ((indices[i] + (vector_ty.getDimSize(i) - 1) * strides[i]) > - memref_ty.getDimSize(i)) { - op.emitError() << "Strided slice is out of range at dim " << i; - return failure(); - } } return success(); } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index e878ebd587a0..e02fdb365085 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -1159,9 +1159,9 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op, } LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op, - Value base_ref, const VectorType &vty, + Value base_ref, ValueRange indices, + const VectorType &vty, const VectorLayout &layout, - const ArrayRef &indices, const ArrayRef &strides) { if (!isa(op)) { return op.emitOpError("Not implemented: Unsupported strided op") @@ -1198,7 +1198,10 @@ LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op, if (strides[rank - 1] != 1) { return op.emitOpError("Not Implemented: Stride on last dim is not 1"); } - if (indices[rank - 1] != 0) { + auto last_idx = getIntConst(indices[rank - 1], /*silent=*/true); + if (failed(last_idx)) { + return op.emitOpError("Not Implemented: Dynamic index on last dim"); + } else if (last_idx.value() != 0) { return op.emitOpError("Not Implemented: Index on last dim is not 0"); } ImplicitLocOpBuilder builder(op.getLoc(), &op); @@ -1224,8 +1227,8 @@ LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op, int64_t stride = (i < rank - 2) ? strides[i] : (strides[i] * ctx.target_shape[i - rank + 2]); - idxs[i] = - IdxConst(indices[i] + tile_idxs[i] * stride, builder, op.getLoc()); + idxs[i] = builder.create( + indices[i], IdxConst(tile_idxs[i] * stride, builder, op.getLoc())); } SmallVector sublane_mask(ctx.target_shape[0], true); int64_t sublane_rem = vty.getDimSize(rank - 2) % ctx.target_shape[0]; @@ -1264,12 +1267,9 @@ LogicalResult tpu_strided_load_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_OP(layouts_out.front().has_value()); const VectorLayout &layout_out = *layouts_out.front(); auto load_op = cast(op); - const auto base_ref = load_op.getBase(); - const auto indices = load_op.getIndices(); - const auto strides = load_op.getStrides(); const auto vty = cast(load_op.getResult().getType()); - return strided_op_rule_impl(ctx, op, base_ref, vty, layout_out, indices, - strides); + return strided_op_rule_impl(ctx, op, load_op.getBase(), load_op.getIndices(), + vty, layout_out, load_op.getStrides()); } // TODO(jevinjiang): maybe unify with vector store? @@ -1283,12 +1283,10 @@ LogicalResult tpu_strided_store_rule(RewriteContext &ctx, Operation &op, const VectorLayout &to_store_layout = *layouts_in.front(); auto store_op = cast(op); - const auto base_ref = store_op.getBase(); - const auto indices = store_op.getIndices(); - const auto strides = store_op.getStrides(); const auto vty = store_op.getValueToStore().getType(); - return strided_op_rule_impl(ctx, op, base_ref, vty, to_store_layout, indices, - strides); + return strided_op_rule_impl(ctx, op, store_op.getBase(), + store_op.getIndices(), vty, to_store_layout, + store_op.getStrides()); } LogicalResult matmul_rule_impl(RewriteContext &ctx, Operation &op, diff --git a/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc index 1eb7c12a828a..8e1246299abf 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc @@ -55,9 +55,11 @@ rule_type as_generic_rule(void (*rule)(Op)) { void assertIsValidSubwindow(Operation *op, mlir::ValueRange base_indices, ArrayRef window_shape, - ArrayRef full_shape) { + ArrayRef full_shape, + ArrayRef strides = {}) { if (base_indices.size() != window_shape.size() || - base_indices.size() != full_shape.size()) { + base_indices.size() != full_shape.size() || + (!strides.empty() && base_indices.size() != strides.size())) { return; // Malformed op. } if (base_indices.empty()) { @@ -68,14 +70,15 @@ void assertIsValidSubwindow(Operation *op, mlir::ValueRange base_indices, for (auto [dim, access] : llvm::enumerate(llvm::zip(base_indices, window_shape, full_shape))) { auto [idx, size, bound] = access; + int64_t stride = strides.empty() ? 1 : strides[dim]; Value positive = builder.create( arith::CmpIPredicate::sge, idx, builder.create(builder.getIntegerAttr(idx_type, 0))); Value in_bounds = builder.create( - arith::CmpIPredicate::sle, + arith::CmpIPredicate::slt, builder.create( idx, builder.create( - builder.getIntegerAttr(idx_type, size))), + builder.getIntegerAttr(idx_type, (size - 1) * stride))), builder.create( builder.getIntegerAttr(idx_type, bound))); std::string msg; @@ -107,6 +110,21 @@ void tpu_memref_slice_rule(tpu::MemRefSliceOp op) { /*full_shape=*/op.getMemRef().getType().getShape()); } +void tpu_strided_load_rule(tpu::StridedLoadOp op) { + assertIsValidSubwindow(op, op.getIndices(), + /*window_shape=*/op.getResult().getType().getShape(), + /*full_shape=*/op.getBase().getType().getShape(), + /*strides=*/op.getStrides()); +} + +void tpu_strided_store_rule(tpu::StridedStoreOp op) { + assertIsValidSubwindow( + op, op.getIndices(), + /*window_shape=*/op.getValueToStore().getType().getShape(), + /*full_shape=*/op.getBase().getType().getShape(), + /*strides=*/op.getStrides()); +} + const llvm::StringMap &rules() { static auto rules = new llvm::StringMap{ // TODO: tpu::LoadOp, tpu::StoreOp @@ -114,6 +132,10 @@ const llvm::StringMap &rules() { {vector::StoreOp::getOperationName(), as_generic_rule(vector_store_rule)}, {tpu::MemRefSliceOp::getOperationName(), as_generic_rule(tpu_memref_slice_rule)}, + {tpu::StridedLoadOp::getOperationName(), + as_generic_rule(tpu_strided_load_rule)}, + {tpu::StridedStoreOp::getOperationName(), + as_generic_rule(tpu_strided_store_rule)}, }; return *rules; } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 59f1a996532a..a69bcc3d88e6 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -616,7 +616,9 @@ class VectorLayoutInferer { } auto store_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth), ImplicitDim::kNone); - setInLayout(op, {store_layout, kNoLayout}); + SmallVector in_layout{op->getNumOperands(), kNoLayout}; + in_layout[0] = store_layout; + setInLayout(op, in_layout); return success(); }