diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
index 51738dc82e86..60f040c95dc4 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
+++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
@@ -59,7 +59,8 @@ std::pair<bool, bool> mightCommunicateBetweenChips(Operation* op);
 std::unique_ptr<OperationPass<func::FuncOp>> createInferMemRefLayoutPass(
     int hardware_generation = -1, const TpuTilingFlags &tpu_tiling_flags = {});
 
-std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass();
+std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
+    int hardware_generation = -1);
 
 std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
     int lane_count = 128, int sublane_count = 8);
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
index 606f29604fe7..0f38fb694f3e 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
@@ -1,5 +1,6 @@
 #include <cstdint>
 #include <memory>
+#include <vector>
 
 #include "llvm/ADT/STLExtras.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -8,6 +9,8 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 // NOLINTNEXTLINE(misc-include-cleaner)
 #include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringSet.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LogicalResult.h"
@@ -19,6 +22,7 @@
 #include "mlir/include/mlir/IR/Block.h"
 #include "mlir/include/mlir/IR/Builders.h"
 #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/include/mlir/IR/OpDefinition.h"
 #include "mlir/include/mlir/IR/Operation.h"
 #include "mlir/include/mlir/IR/Region.h"
 #include "mlir/include/mlir/IR/Value.h"
@@ -111,6 +115,85 @@ LogicalResult tpu_matmul_rule(tpu::MatmulOp op) {
   return success();
 };
 
+// Canonicalizes a vector elementwise op: bf16 vector operands are extended to
+// f32 when hardware_generation_ <= 5, or for math.powf / arith.divf on any
+// generation. If any operand was extended, the op is recreated in f32 and its
+// result truncated back to the original result type. Ops with a non-vector
+// result are left untouched. Emits an op error and returns failure() when an
+// invariant is violated.
+LogicalResult canonicalize_elementwise(int hardware_generation_,
+                                       Operation &op) {
+  OpBuilder builder(&op);
+  auto operands = op.getOperands();
+  // Check the result count before touching getResult(0).
+  if (op.getNumResults() != 1) {
+    op.emitOpError("Invariant violated: Unexpected number of results");
+    return failure();
+  }
+  auto res_ty = dyn_cast<VectorType>(op.getResult(0).getType());
+  if (!res_ty) {
+    // scalar
+    // TODO(mvoz): Add canonicalization and invariants for scalar elementwise
+    // ops.
+    return success();
+  }
+  auto shape = res_ty.getShape();
+  std::vector<Value> new_operands;
+  new_operands.reserve(operands.size());
+
+  bool should_rewrite_op = false;
+  auto target_f32_ty = VectorType::get(shape, builder.getF32Type());
+  for (int i = 0; i < operands.size(); ++i) {
+    auto operand = operands[i];
+    auto ty = dyn_cast<VectorType>(operand.getType());
+    if (ty) {
+      if (ty.getShape() != shape) {
+        // Should already be checked by MLIR verification, but let's be safe.
+        op.emitOpError("Mismatched shapes in elementwise op.");
+        return failure();
+      }
+      auto element_type = ty.getElementType();
+      // PowFOp and DivFOp do not seem to be supported in bf16 on later
+      // hardware.
+      bool needs_cast = hardware_generation_ <= 5 || isa<math::PowFOp>(op) ||
+                        isa<arith::DivFOp>(op);
+      if (needs_cast && element_type.isBF16()) {
+        auto target_f32 =
+            builder.create<arith::ExtFOp>(op.getLoc(), target_f32_ty, operand)
+                .getResult();
+        should_rewrite_op = true;
+        new_operands.push_back(target_f32);
+      } else {
+        new_operands.push_back(operand);
+      }
+    } else {
+      // Should already be checked by MLIR verification, but let's be safe.
+      op.emitOpError("MLIR unsupported - mix scalar and vec elementwise ops");
+      return failure();
+    }
+  }
+  if (should_rewrite_op) {
+    auto result_ty = dyn_cast<VectorType>(op.getResult(0).getType());
+    if (!result_ty) {
+      op.emitOpError("Not implemented: Unexpected result type");
+      return failure();
+    }
+    auto result_element_type = result_ty.getElementType();
+    if (!result_element_type.isF32() && !result_element_type.isBF16()) {
+      op.emitOpError("Not implemented: Unexpected result element type");
+      return failure();
+    }
+    // Do the new op in f32, then truncate to the original element type.
+    auto new_op = builder.create(op.getLoc(), op.getName().getIdentifier(),
+                                 new_operands, target_f32_ty);
+    new_op = builder.create<arith::TruncFOp>(op.getLoc(), res_ty,
+                                             new_op->getResult(0));
+    op.replaceAllUsesWith(new_op);
+    op.erase();
+  }
+  return success();
+}
+
 LogicalResult canonicalize_matmul(Operation &op) {
   auto matmul_op = dyn_cast<tpu::MatmulOp>(op);
   if (!matmul_op) {
@@ -196,9 +279,24 @@ const llvm::StringMap<canonicalize_rule_type> &rules() {
   return *rules;
 }
 
+// Names of the float elementwise ops that are routed to
+// canonicalize_elementwise.
+const llvm::StringSet<> &elementwise_convertible_ops() {
+  static auto ops = new llvm::StringSet<>{arith::MulFOp::getOperationName(),
+                                          arith::DivFOp::getOperationName(),
+                                          arith::AddFOp::getOperationName(),
+                                          arith::SubFOp::getOperationName(),
+                                          arith::MaximumFOp::getOperationName(),
+                                          arith::MinimumFOp::getOperationName(),
+                                          math::PowFOp::getOperationName()};
+  return *ops;
+}
+
 class MosaicCanonicalizer {
  public:
-  MosaicCanonicalizer() {}
+  MosaicCanonicalizer(int hardware_generation)
+      : hardware_generation_(hardware_generation) {}
+
+  int hardware_generation_;
 
   LogicalResult canonicalize(func::FuncOp op) {
     if (!op.getBody().hasOneBlock()) {
@@ -229,6 +327,10 @@ class MosaicCanonicalizer {
         }
       }
     }
+    if (elementwise_convertible_ops().contains(
+            any_op.getName().getStringRef())) {
+      return canonicalize_elementwise(hardware_generation_, any_op);
+    }
     if (auto rule_it = rules().find(any_op.getName().getStringRef());
         rule_it != rules().end()) {
       const canonicalize_rule_type &rule = rule_it->getValue();
@@ -240,19 +342,23 @@ class MosaicCanonicalizer {
 
 struct CanonicalizeMosaicPass
     : public impl::CanonicalizeMosaicPassBase<CanonicalizeMosaicPass> {
-  CanonicalizeMosaicPass() {}
+  CanonicalizeMosaicPass(int hardware_generation)
+      : hardware_generation_(hardware_generation) {}
+
+  int hardware_generation_;
 
   void runOnOperation() override {
     func::FuncOp func = getOperation();
-    MosaicCanonicalizer vlc;
+    MosaicCanonicalizer vlc(hardware_generation_);
     if (vlc.canonicalize(func).failed()) {
       signalPassFailure();
     }
   };
 };
 
-std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass() {
-  return std::make_unique<CanonicalizeMosaicPass>();
+std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
+    int hardware_generation) {
+  return std::make_unique<CanonicalizeMosaicPass>(hardware_generation);
 }
 
 }  // namespace mlir::tpu