Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA:Mosaic] Support dynamic indices in strided load/store. #20220

Merged
merged 1 commit into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,12 @@ def TPU_LoadOp : TPU_Op<"load"> {
def TPU_StridedLoadOp : TPU_Op<"strided_load"> {
let arguments = (ins
AnyMemRef:$base,
DenseI32ArrayAttr:$indices,
Variadic<Index>:$indices,
DenseI32ArrayAttr:$strides
);
let results = (outs AnyVector:$result);
let assemblyFormat = [{
$base attr-dict `:` type($base) `,` type($result)
$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)
}];
let hasVerifier = 1;
}
Expand All @@ -211,12 +211,12 @@ def TPU_StridedStoreOp : TPU_Op<"strided_store"> {
let arguments = (ins
AnyVector:$valueToStore,
AnyMemRef:$base,
DenseI32ArrayAttr:$indices,
Variadic<Index>:$indices,
DenseI32ArrayAttr:$strides
);
let results = (outs);
let assemblyFormat = [{
$base `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore)
$base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore)
}];
let hasVerifier = 1;
}
Expand Down
11 changes: 0 additions & 11 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,21 +202,10 @@ LogicalResult verifyStridedOp(Op op, MemRefType memref_ty,
return failure();
}
for (int64_t i = 0; i < memref_ty.getRank(); ++i) {
if (indices[i] < 0 && indices[i] >= memref_ty.getDimSize(i)) {
op.emitError("Indices[")
<< i << "]=" << indices[i] << " is out of range [0, "
<< memref_ty.getDimSize(i) << ")";
return failure();
}
if (strides[i] < 1) {
op.emitError("Strides[") << i << "]=" << strides[i] << " must be >= 1";
return failure();
}
if ((indices[i] + (vector_ty.getDimSize(i) - 1) * strides[i]) >
memref_ty.getDimSize(i)) {
op.emitError() << "Strided slice is out of range at dim " << i;
return failure();
}
}
return success();
}
Expand Down
28 changes: 13 additions & 15 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1159,9 +1159,9 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op,
}

LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op,
Value base_ref, const VectorType &vty,
Value base_ref, ValueRange indices,
const VectorType &vty,
const VectorLayout &layout,
const ArrayRef<int32_t> &indices,
const ArrayRef<int32_t> &strides) {
if (!isa<tpu::StridedLoadOp, tpu::StridedStoreOp>(op)) {
return op.emitOpError("Not implemented: Unsupported strided op")
Expand Down Expand Up @@ -1198,7 +1198,10 @@ LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op,
if (strides[rank - 1] != 1) {
return op.emitOpError("Not Implemented: Stride on last dim is not 1");
}
if (indices[rank - 1] != 0) {
auto last_idx = getIntConst(indices[rank - 1], /*silent=*/true);
if (failed(last_idx)) {
return op.emitOpError("Not Implemented: Dynamic index on last dim");
} else if (last_idx.value() != 0) {
return op.emitOpError("Not Implemented: Index on last dim is not 0");
}
ImplicitLocOpBuilder builder(op.getLoc(), &op);
Expand All @@ -1224,8 +1227,8 @@ LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op,
int64_t stride = (i < rank - 2)
? strides[i]
: (strides[i] * ctx.target_shape[i - rank + 2]);
idxs[i] =
IdxConst(indices[i] + tile_idxs[i] * stride, builder, op.getLoc());
idxs[i] = builder.create<arith::AddIOp>(
indices[i], IdxConst(tile_idxs[i] * stride, builder, op.getLoc()));
}
SmallVector<bool> sublane_mask(ctx.target_shape[0], true);
int64_t sublane_rem = vty.getDimSize(rank - 2) % ctx.target_shape[0];
Expand Down Expand Up @@ -1264,12 +1267,9 @@ LogicalResult tpu_strided_load_rule(RewriteContext &ctx, Operation &op,
TPU_ASSERT_OP(layouts_out.front().has_value());
const VectorLayout &layout_out = *layouts_out.front();
auto load_op = cast<tpu::StridedLoadOp>(op);
const auto base_ref = load_op.getBase();
const auto indices = load_op.getIndices();
const auto strides = load_op.getStrides();
const auto vty = cast<VectorType>(load_op.getResult().getType());
return strided_op_rule_impl(ctx, op, base_ref, vty, layout_out, indices,
strides);
return strided_op_rule_impl(ctx, op, load_op.getBase(), load_op.getIndices(),
vty, layout_out, load_op.getStrides());
}

// TODO(jevinjiang): maybe unify with vector store?
Expand All @@ -1283,12 +1283,10 @@ LogicalResult tpu_strided_store_rule(RewriteContext &ctx, Operation &op,

const VectorLayout &to_store_layout = *layouts_in.front();
auto store_op = cast<tpu::StridedStoreOp>(op);
const auto base_ref = store_op.getBase();
const auto indices = store_op.getIndices();
const auto strides = store_op.getStrides();
const auto vty = store_op.getValueToStore().getType();
return strided_op_rule_impl(ctx, op, base_ref, vty, to_store_layout, indices,
strides);
return strided_op_rule_impl(ctx, op, store_op.getBase(),
store_op.getIndices(), vty, to_store_layout,
store_op.getStrides());
}

LogicalResult matmul_rule_impl(RewriteContext &ctx, Operation &op,
Expand Down
30 changes: 26 additions & 4 deletions jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ rule_type as_generic_rule(void (*rule)(Op)) {

void assertIsValidSubwindow(Operation *op, mlir::ValueRange base_indices,
ArrayRef<int64_t> window_shape,
ArrayRef<int64_t> full_shape) {
ArrayRef<int64_t> full_shape,
ArrayRef<int32_t> strides = {}) {
if (base_indices.size() != window_shape.size() ||
base_indices.size() != full_shape.size()) {
base_indices.size() != full_shape.size() ||
(!strides.empty() && base_indices.size() != strides.size())) {
return; // Malformed op.
}
if (base_indices.empty()) {
Expand All @@ -68,14 +70,15 @@ void assertIsValidSubwindow(Operation *op, mlir::ValueRange base_indices,
for (auto [dim, access] :
llvm::enumerate(llvm::zip(base_indices, window_shape, full_shape))) {
auto [idx, size, bound] = access;
int64_t stride = strides.empty() ? 1 : strides[dim];
Value positive = builder.create<arith::CmpIOp>(
arith::CmpIPredicate::sge, idx,
builder.create<arith::ConstantOp>(builder.getIntegerAttr(idx_type, 0)));
Value in_bounds = builder.create<arith::CmpIOp>(
arith::CmpIPredicate::sle,
arith::CmpIPredicate::slt,
builder.create<arith::AddIOp>(
idx, builder.create<arith::ConstantOp>(
builder.getIntegerAttr(idx_type, size))),
builder.getIntegerAttr(idx_type, (size - 1) * stride))),
builder.create<arith::ConstantOp>(
builder.getIntegerAttr(idx_type, bound)));
std::string msg;
Expand Down Expand Up @@ -107,13 +110,32 @@ void tpu_memref_slice_rule(tpu::MemRefSliceOp op) {
/*full_shape=*/op.getMemRef().getType().getShape());
}

// Debug-assertion rule for tpu.strided_load.
//
// Hands the op's index operands, the shape of the loaded vector, the shape of
// the base memref and the per-dimension strides to assertIsValidSubwindow,
// which is responsible for emitting the actual checks.
void tpu_strided_load_rule(tpu::StridedLoadOp op) {
  auto loaded_ty = op.getResult().getType();
  auto memref_ty = op.getBase().getType();
  assertIsValidSubwindow(op, op.getIndices(),
                         /*window_shape=*/loaded_ty.getShape(),
                         /*full_shape=*/memref_ty.getShape(),
                         /*strides=*/op.getStrides());
}

// Debug-assertion rule for tpu.strided_store.
//
// Hands the op's index operands, the shape of the vector being stored, the
// shape of the base memref and the per-dimension strides to
// assertIsValidSubwindow, which is responsible for emitting the actual checks.
void tpu_strided_store_rule(tpu::StridedStoreOp op) {
  auto stored_ty = op.getValueToStore().getType();
  auto memref_ty = op.getBase().getType();
  assertIsValidSubwindow(op, op.getIndices(),
                         /*window_shape=*/stored_ty.getShape(),
                         /*full_shape=*/memref_ty.getShape(),
                         /*strides=*/op.getStrides());
}

// Returns the table mapping an operation name to the rule that inserts debug
// assertions for that operation.
//
// The table is built once, on first call, and is allocated with `new` without
// a matching delete in this function, so the returned reference stays valid
// for every later call.
const llvm::StringMap<rule_type> &rules() {
  static auto *table = [] {
    auto *map = new llvm::StringMap<rule_type>();
    // TODO: tpu::LoadOp, tpu::StoreOp
    map->try_emplace(vector::LoadOp::getOperationName(),
                     as_generic_rule(vector_load_rule));
    map->try_emplace(vector::StoreOp::getOperationName(),
                     as_generic_rule(vector_store_rule));
    map->try_emplace(tpu::MemRefSliceOp::getOperationName(),
                     as_generic_rule(tpu_memref_slice_rule));
    map->try_emplace(tpu::StridedLoadOp::getOperationName(),
                     as_generic_rule(tpu_strided_load_rule));
    map->try_emplace(tpu::StridedStoreOp::getOperationName(),
                     as_generic_rule(tpu_strided_store_rule));
    return map;
  }();
  return *table;
}
Expand Down
4 changes: 3 additions & 1 deletion jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,9 @@ class VectorLayoutInferer {
}
auto store_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
ImplicitDim::kNone);
setInLayout(op, {store_layout, kNoLayout});
SmallVector<Layout, 5> in_layout{op->getNumOperands(), kNoLayout};
in_layout[0] = store_layout;
setInLayout(op, in_layout);
return success();
}

Expand Down