Skip to content

Commit

Permalink
[XLA:Mosaic] Support dynamic indices in strided load/store.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615931990
  • Loading branch information
bythew3i authored and jax authors committed Mar 14, 2024
1 parent ac41032 commit 7578e10
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 35 deletions.
8 changes: 4 additions & 4 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,12 @@ def TPU_LoadOp : TPU_Op<"load"> {
def TPU_StridedLoadOp : TPU_Op<"strided_load"> {
let arguments = (ins
AnyMemRef:$base,
DenseI32ArrayAttr:$indices,
Variadic<Index>:$indices,
DenseI32ArrayAttr:$strides
);
let results = (outs AnyVector:$result);
let assemblyFormat = [{
$base attr-dict `:` type($base) `,` type($result)
$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)
}];
let hasVerifier = 1;
}
Expand All @@ -211,12 +211,12 @@ def TPU_StridedStoreOp : TPU_Op<"strided_store"> {
let arguments = (ins
AnyVector:$valueToStore,
AnyMemRef:$base,
DenseI32ArrayAttr:$indices,
Variadic<Index>:$indices,
DenseI32ArrayAttr:$strides
);
let results = (outs);
let assemblyFormat = [{
$base `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore)
$base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore)
}];
let hasVerifier = 1;
}
Expand Down
11 changes: 0 additions & 11 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,21 +202,10 @@ LogicalResult verifyStridedOp(Op op, MemRefType memref_ty,
return failure();
}
for (int64_t i = 0; i < memref_ty.getRank(); ++i) {
if (indices[i] < 0 && indices[i] >= memref_ty.getDimSize(i)) {
op.emitError("Indices[")
<< i << "]=" << indices[i] << " is out of range [0, "
<< memref_ty.getDimSize(i) << ")";
return failure();
}
if (strides[i] < 1) {
op.emitError("Strides[") << i << "]=" << strides[i] << " must be >= 1";
return failure();
}
if ((indices[i] + (vector_ty.getDimSize(i) - 1) * strides[i]) >
memref_ty.getDimSize(i)) {
op.emitError() << "Strided slice is out of range at dim " << i;
return failure();
}
}
return success();
}
Expand Down
28 changes: 13 additions & 15 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1159,9 +1159,9 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op,
}

LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op,
Value base_ref, const VectorType &vty,
Value base_ref, ValueRange indices,
const VectorType &vty,
const VectorLayout &layout,
const ArrayRef<int32_t> &indices,
const ArrayRef<int32_t> &strides) {
if (!isa<tpu::StridedLoadOp, tpu::StridedStoreOp>(op)) {
return op.emitOpError("Not implemented: Unsupported strided op")
Expand Down Expand Up @@ -1198,7 +1198,10 @@ LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op,
if (strides[rank - 1] != 1) {
return op.emitOpError("Not Implemented: Stride on last dim is not 1");
}
if (indices[rank - 1] != 0) {
auto last_idx = getIntConst(indices[rank - 1], /*silent=*/true);
if (failed(last_idx)) {
return op.emitOpError("Not Implemented: Dynamic index on last dim");
} else if (last_idx.value() != 0) {
return op.emitOpError("Not Implemented: Index on last dim is not 0");
}
ImplicitLocOpBuilder builder(op.getLoc(), &op);
Expand All @@ -1224,8 +1227,8 @@ LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op,
int64_t stride = (i < rank - 2)
? strides[i]
: (strides[i] * ctx.target_shape[i - rank + 2]);
idxs[i] =
IdxConst(indices[i] + tile_idxs[i] * stride, builder, op.getLoc());
idxs[i] = builder.create<arith::AddIOp>(
indices[i], IdxConst(tile_idxs[i] * stride, builder, op.getLoc()));
}
SmallVector<bool> sublane_mask(ctx.target_shape[0], true);
int64_t sublane_rem = vty.getDimSize(rank - 2) % ctx.target_shape[0];
Expand Down Expand Up @@ -1264,12 +1267,9 @@ LogicalResult tpu_strided_load_rule(RewriteContext &ctx, Operation &op,
TPU_ASSERT_OP(layouts_out.front().has_value());
const VectorLayout &layout_out = *layouts_out.front();
auto load_op = cast<tpu::StridedLoadOp>(op);
const auto base_ref = load_op.getBase();
const auto indices = load_op.getIndices();
const auto strides = load_op.getStrides();
const auto vty = cast<VectorType>(load_op.getResult().getType());
return strided_op_rule_impl(ctx, op, base_ref, vty, layout_out, indices,
strides);
return strided_op_rule_impl(ctx, op, load_op.getBase(), load_op.getIndices(),
vty, layout_out, load_op.getStrides());
}

// TODO(jevinjiang): maybe unify with vector store?
Expand All @@ -1283,12 +1283,10 @@ LogicalResult tpu_strided_store_rule(RewriteContext &ctx, Operation &op,

const VectorLayout &to_store_layout = *layouts_in.front();
auto store_op = cast<tpu::StridedStoreOp>(op);
const auto base_ref = store_op.getBase();
const auto indices = store_op.getIndices();
const auto strides = store_op.getStrides();
const auto vty = store_op.getValueToStore().getType();
return strided_op_rule_impl(ctx, op, base_ref, vty, to_store_layout, indices,
strides);
return strided_op_rule_impl(ctx, op, store_op.getBase(),
store_op.getIndices(), vty, to_store_layout,
store_op.getStrides());
}

LogicalResult matmul_rule_impl(RewriteContext &ctx, Operation &op,
Expand Down
30 changes: 26 additions & 4 deletions jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,11 @@ rule_type as_generic_rule(void (*rule)(Op)) {

void assertIsValidSubwindow(Operation *op, mlir::ValueRange base_indices,
ArrayRef<int64_t> window_shape,
ArrayRef<int64_t> full_shape) {
ArrayRef<int64_t> full_shape,
ArrayRef<int32_t> strides = {}) {
if (base_indices.size() != window_shape.size() ||
base_indices.size() != full_shape.size()) {
base_indices.size() != full_shape.size() ||
(!strides.empty() && base_indices.size() != strides.size())) {
return; // Malformed op.
}
if (base_indices.empty()) {
Expand All @@ -68,14 +70,15 @@ void assertIsValidSubwindow(Operation *op, mlir::ValueRange base_indices,
for (auto [dim, access] :
llvm::enumerate(llvm::zip(base_indices, window_shape, full_shape))) {
auto [idx, size, bound] = access;
int64_t stride = strides.empty() ? 1 : strides[dim];
Value positive = builder.create<arith::CmpIOp>(
arith::CmpIPredicate::sge, idx,
builder.create<arith::ConstantOp>(builder.getIntegerAttr(idx_type, 0)));
Value in_bounds = builder.create<arith::CmpIOp>(
arith::CmpIPredicate::sle,
arith::CmpIPredicate::slt,
builder.create<arith::AddIOp>(
idx, builder.create<arith::ConstantOp>(
builder.getIntegerAttr(idx_type, size))),
builder.getIntegerAttr(idx_type, (size - 1) * stride))),
builder.create<arith::ConstantOp>(
builder.getIntegerAttr(idx_type, bound)));
std::string msg;
Expand Down Expand Up @@ -107,13 +110,32 @@ void tpu_memref_slice_rule(tpu::MemRefSliceOp op) {
/*full_shape=*/op.getMemRef().getType().getShape());
}

void tpu_strided_load_rule(tpu::StridedLoadOp op) {
assertIsValidSubwindow(op, op.getIndices(),
/*window_shape=*/op.getResult().getType().getShape(),
/*full_shape=*/op.getBase().getType().getShape(),
/*strides=*/op.getStrides());
}

void tpu_strided_store_rule(tpu::StridedStoreOp op) {
assertIsValidSubwindow(
op, op.getIndices(),
/*window_shape=*/op.getValueToStore().getType().getShape(),
/*full_shape=*/op.getBase().getType().getShape(),
/*strides=*/op.getStrides());
}

const llvm::StringMap<rule_type> &rules() {
static auto rules = new llvm::StringMap<rule_type>{
// TODO: tpu::LoadOp, tpu::StoreOp
{vector::LoadOp::getOperationName(), as_generic_rule(vector_load_rule)},
{vector::StoreOp::getOperationName(), as_generic_rule(vector_store_rule)},
{tpu::MemRefSliceOp::getOperationName(),
as_generic_rule(tpu_memref_slice_rule)},
{tpu::StridedLoadOp::getOperationName(),
as_generic_rule(tpu_strided_load_rule)},
{tpu::StridedStoreOp::getOperationName(),
as_generic_rule(tpu_strided_store_rule)},
};
return *rules;
}
Expand Down
4 changes: 3 additions & 1 deletion jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,9 @@ class VectorLayoutInferer {
}
auto store_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
ImplicitDim::kNone);
setInLayout(op, {store_layout, kNoLayout});
SmallVector<Layout, 5> in_layout{op->getNumOperands(), kNoLayout};
in_layout[0] = store_layout;
setInLayout(op, in_layout);
return success();
}

Expand Down

0 comments on commit 7578e10

Please sign in to comment.