
Commit

Add support for elementwise op canonicalization in fp32 for older hardware.

PiperOrigin-RevId: 651959463
Google-ML-Automation authored and jax authors committed Jul 13, 2024
1 parent 0dfb206 commit 764ec92
Showing 2 changed files with 105 additions and 6 deletions.
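In effect, the new pass widens bf16 elementwise ops to f32 and truncates the result back. A minimal sketch of the emitted pattern, mirroring the builder calls in the diff below; `builder`, `op`, `loc`, `operand`, `f32_vec_ty`, and `bf16_vec_ty` are stand-ins for values the pass derives from the matched op:

// Sketch only, not part of the commit: the rewrite canonicalize_elementwise
// emits for a single bf16 operand on generation <= 5 (or for pow/div on any
// generation). Widen the bf16 operand to f32 first.
Value wide = builder.create<arith::ExtFOp>(loc, f32_vec_ty, operand);
// Re-create the original op with f32 operands and an f32 result type.
std::vector<Value> wide_operands = {wide};
Operation *wide_op = builder.create(loc, op.getName().getIdentifier(),
                                    wide_operands, f32_vec_ty);
// Truncate back to the op's original bf16 result type.
Value result = builder.create<arith::TruncFOp>(loc, bf16_vec_ty,
                                               wide_op->getResult(0));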
3 changes: 2 additions & 1 deletion jaxlib/mosaic/dialect/tpu/tpu_dialect.h
@@ -59,7 +59,8 @@ std::pair<bool, bool> mightCommunicateBetweenChips(Operation* op);
 std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
     int hardware_generation = -1, const TpuTilingFlags &tpu_tiling_flags = {});
 
-std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
+    int hardware_generation = -1);
 
 std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
     int lane_count = 128, int sublane_count = 8);
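Note the default `hardware_generation = -1`, which keeps existing call sites compiling unchanged. A hypothetical caller threading a real generation through a standard MLIR pass pipeline (the `PassManager` wiring is illustrative, not part of this commit):

// Hypothetical usage sketch: schedule the pass for TPU generation 5, so
// all listed bf16 elementwise ops get widened to f32.
mlir::PassManager pm(&context);
pm.addNestedPass<mlir::func::FuncOp>(
    mlir::tpu::createCanonicalizeMosaicPass(/*hardware_generation=*/5));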
108 changes: 103 additions & 5 deletions jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
@@ -1,5 +1,6 @@
 #include <functional>
 #include <memory>
+#include <vector>
 
 #include "llvm/ADT/STLExtras.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -8,6 +9,8 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 // NOLINTNEXTLINE(misc-include-cleaner)
 #include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringSet.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LogicalResult.h"
@@ -19,6 +22,7 @@
 #include "mlir/include/mlir/IR/Block.h"
 #include "mlir/include/mlir/IR/Builders.h"
 #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/include/mlir/IR/OpDefinition.h"
 #include "mlir/include/mlir/IR/Operation.h"
 #include "mlir/include/mlir/IR/Region.h"
 #include "mlir/include/mlir/IR/Value.h"
@@ -111,6 +115,78 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) {
   return success();
 };
 
+LogicalResult canonicalize_elementwise(int hardware_generation_,
+                                       Operation &op) {
+  OpBuilder builder(&op);
+  auto operands = op.getOperands();
+  if (op.getNumResults() != 1) {
+    op.emitOpError("Invariant violated: Unexpected number of results");
+    return failure();
+  }
+  auto res_ty = dyn_cast<VectorType>(op.getResult(0).getType());
+  if (!res_ty) {
+    // Scalar case.
+    // TODO(mvoz): Add canonicalization and invariants for scalar elementwise
+    // ops.
+    return success();
+  }
+  auto shape = res_ty.getShape();
+  std::vector<Value> new_operands;
+  new_operands.reserve(operands.size());
+
+  bool should_rewrite_op = false;
+  auto target_f32_ty = VectorType::get(shape, builder.getF32Type());
+  for (int i = 0; i < operands.size(); ++i) {
+    auto operand = operands[i];
+    auto ty = dyn_cast<VectorType>(operand.getType());
+    if (ty) {
+      if (ty.getShape() != shape) {
+        // Should already be checked by MLIR verification, but let's be safe.
+        op.emitOpError("Mismatched shapes in elementwise op.");
+        return failure();
+      }
+      auto element_type = ty.getElementType();
+      // PowFOp and DivFOp do not seem to be supported in bf16 even on later
+      // hardware.
+      bool needs_cast = hardware_generation_ <= 5 || isa<math::PowFOp>(op) ||
+                        isa<arith::DivFOp>(op);
+      if (needs_cast && element_type.isBF16()) {
+        auto target_f32 =
+            builder.create<arith::ExtFOp>(op.getLoc(), target_f32_ty, operand)
+                .getResult();
+        should_rewrite_op = true;
+        new_operands.push_back(target_f32);
+      } else {
+        new_operands.push_back(operand);
+      }
+    } else {
+      // Should already be checked by MLIR verification, but let's be safe.
+      op.emitOpError("MLIR unsupported - mix scalar and vec elementwise ops");
+      return failure();
+    }
+  }
+  if (should_rewrite_op) {
+    auto result_ty = dyn_cast<VectorType>(op.getResult(0).getType());
+    if (!result_ty) {
+      op.emitOpError("Not implemented: Unexpected result type");
+      return failure();
+    }
+    auto result_element_type = result_ty.getElementType();
+    if (!result_element_type.isF32() && !result_element_type.isBF16()) {
+      op.emitOpError("Not implemented: Unexpected result element type");
+      return failure();
+    }
+    // Do the new op in f32, then truncate to the original element type.
+    auto new_op = builder.create(op.getLoc(), op.getName().getIdentifier(),
+                                 new_operands, target_f32_ty);
+    new_op = builder.create<arith::TruncFOp>(op.getLoc(), res_ty,
+                                             new_op->getResult(0));
+    op.replaceAllUsesWith(new_op);
+    op.erase();
+  }
+  return success();
+}
+
 LogicalResult canonicalize_matmul(Operation &op) {
   auto matmul_op = dyn_cast<tpu::MatmulOp>(op);
   if (!matmul_op) {
@@ -196,9 +272,23 @@ const llvm::StringMap<canonicalize_rule_type> &rules() {
   return *rules;
 }
 
+const llvm::StringSet<> &elementwise_convertible_ops() {
+  static auto ops = new llvm::StringSet<>{arith::MulFOp::getOperationName(),
+                                          arith::DivFOp::getOperationName(),
+                                          arith::AddFOp::getOperationName(),
+                                          arith::SubFOp::getOperationName(),
+                                          arith::MaximumFOp::getOperationName(),
+                                          arith::MinimumFOp::getOperationName(),
+                                          math::PowFOp::getOperationName()};
+  return *ops;
+}
+
 class MosaicCanonicalizer {
  public:
-  MosaicCanonicalizer() {}
+  MosaicCanonicalizer(int hardware_generation)
+      : hardware_generation_(hardware_generation) {}
+
+  int hardware_generation_;
 
   LogicalResult canonicalize(func::FuncOp op) {
     if (!op.getBody().hasOneBlock()) {
@@ -229,6 +319,10 @@
         }
       }
     }
+    if (elementwise_convertible_ops().contains(
+            any_op.getName().getStringRef())) {
+      return canonicalize_elementwise(hardware_generation_, any_op);
+    }
     if (auto rule_it = rules().find(any_op.getName().getStringRef());
         rule_it != rules().end()) {
       const canonicalize_rule_type &rule = rule_it->getValue();
@@ -240,19 +334,23 @@
 
 struct CanonicalizeMosaicPass
     : public impl::CanonicalizeMosaicPassBase<CanonicalizeMosaicPass> {
-  CanonicalizeMosaicPass() {}
+  CanonicalizeMosaicPass(int hardware_generation)
+      : hardware_generation_(hardware_generation) {}
 
+  int hardware_generation_;
+
   void runOnOperation() override {
     func::FuncOp func = getOperation();
-    MosaicCanonicalizer vlc;
+    MosaicCanonicalizer vlc(hardware_generation_);
     if (vlc.canonicalize(func).failed()) {
       signalPassFailure();
     }
   };
 };
 
-std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass() {
-  return std::make_unique<CanonicalizeMosaicPass>();
+std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
+    int hardware_generation) {
+  return std::make_unique<CanonicalizeMosaicPass>(hardware_generation);
 }
 
 } // namespace mlir::tpu
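Two implementation details are worth noting. First, the widening condition combines the generation gate with op-specific exceptions, so pow and div are widened even on newer chips:

// Restatement of the condition from canonicalize_elementwise above:
bool needs_cast = hardware_generation <= 5 ||  // older TPUs: widen all ops
                  isa<math::PowFOp>(op) ||     // pow: no bf16 support
                  isa<arith::DivFOp>(op);      // div: no bf16 support

Second, the factory's default `hardware_generation` of -1 satisfies `<= 5`, so an unconfigured pass widens every listed bf16 elementwise op unless a caller supplies a real generation.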
