
Commit f603104

Reduce-Mul-Broadcast to Dot General (#1669)

* Reduce-Mul-Broadcast to Dot General
* unfmt

1 parent fadcc11 commit f603104
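
In StableHLO terms, the new pattern recognizes an add-reduce applied to the elementwise product of two broadcast_in_dim results and rewrites it to a single stablehlo.dot_general that contracts over the reduced dimension, instead of materializing the rank-expanded intermediate. Below is a minimal JAX sketch of the algebraic identity the rewrite exploits (not part of the commit; the shapes and names are illustrative, and the dims choices mirror the lit test at the end of this diff):

import jax.numpy as jnp
import numpy as np

I, J, K = 4, 5, 6
A = np.random.rand(I, K).astype(np.float32)  # lhs operand
B = np.random.rand(J, I).astype(np.float32)  # rhs operand

# broadcast_in_dim A, dims = [0, 2]: the value at [i, j, k] is A[i, k].
lhs3 = jnp.broadcast_to(A[:, None, :], (I, J, K))
# broadcast_in_dim B, dims = [1, 0]: the value at [i, j, k] is B[j, i].
rhs3 = jnp.broadcast_to(B.T[:, :, None], (I, J, K))

# Add-reduce of the product over the shared dimension i ...
reduced = (lhs3 * rhs3).sum(axis=0)

# ... is a plain contraction over i; no rank-3 temporary is needed.
contracted = jnp.einsum("ik,ji->jk", A, B)
np.testing.assert_allclose(reduced, contracted, rtol=1e-5)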

File tree: 4 files changed, +108 −0 lines

  - src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp (+85)
  - src/enzyme_ad/jax/TransformOps/TransformOps.td (+5)
  - src/enzyme_ad/jax/primitives.py (+1)
  - new lit test (+17; file path not shown in this view)

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 85 additions & 0 deletions

@@ -25877,6 +25877,90 @@ struct BinaryNegatedOperandsSimplify
   }
 };
 
+struct ReduceMulBroadcastToDotGeneral
+    : public CheckedOpRewritePattern<stablehlo::ReduceOp,
+                                     ReduceMulBroadcastToDotGeneral> {
+  using CheckedOpRewritePattern<
+      stablehlo::ReduceOp,
+      ReduceMulBroadcastToDotGeneral>::CheckedOpRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(stablehlo::ReduceOp op,
+                                    PatternRewriter &rewriter) const {
+    if (op.getInputs().size() != 1 || op.getInitValues().size() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "only single-operand single-init reduce is supported");
+    }
+
+    auto dims = op.getDimensions();
+
+    Value input = op.getInputs()[0];
+    auto TT = cast<TensorType>(input.getType());
+    auto OT = cast<TensorType>(op.getResultTypes()[0]);
+
+    if (OT.getRank() != 2 || dims.size() != 1)
+      return failure();
+
+    auto checkCommonReduce = mlir::stablehlo::CheckCommonReduceOp(op);
+    if (!checkCommonReduce.isAddReduce ||
+        !matchPattern(op.getInitValues()[0], m_AnyZeroFloat()))
+      return rewriter.notifyMatchFailure(op, "reduction is not add");
+
+    auto mul = input.getDefiningOp<stablehlo::MulOp>();
+    if (!mul)
+      return rewriter.notifyMatchFailure(op, "input source is not a mul op");
+
+    Value mulLhs = mul.getLhs(), mulRhs = mul.getRhs();
+    auto lhsBdim = mulLhs.getDefiningOp<stablehlo::BroadcastInDimOp>(),
+         rhsBdim = mulRhs.getDefiningOp<stablehlo::BroadcastInDimOp>();
+
+    if (!lhsBdim || !rhsBdim)
+      return failure();
+
+    auto prepareInputForDotGeneral =
+        [&](stablehlo::BroadcastInDimOp bdim) -> Value {
+      // transpose dims: [0, 2] -> [0, 1]
+      // transpose dims: [1, 0] -> [1, 0]
+      auto OT = cast<TensorType>(bdim.getResult().getType());
+
+      auto bdims = bdim.getBroadcastDimensions();
+      SmallVector<int64_t> transposeDims(bdims.size(), -1);
+
+      int64_t ncdims = 0;
+      for (int i = 0; i < OT.getRank(); i++) {
+        bool inBDims = false;
+        for (auto [j, dim] : llvm::enumerate(bdims)) {
+          if (dim == i) {
+            inBDims = true;
+            transposeDims[j] = i - ncdims;
+            break;
+          }
+        }
+        if (!inBDims) {
+          ncdims++;
+        }
+      }
+
+      Value prepared = stablehlo::TransposeOp::create(
+          rewriter, bdim.getLoc(), bdim.getOperand(), transposeDims);
+
+      return prepared;
+    };
+
+    auto lhs = prepareInputForDotGeneral(lhsBdim);
+    auto rhs = prepareInputForDotGeneral(rhsBdim);
+
+    auto ndim = stablehlo::DotDimensionNumbersAttr::get(
+        op.getContext(), {}, {}, op.getDimensions(), op.getDimensions());
+
+    auto dg = DotGeneralOp::create(rewriter, op.getLoc(), OT, lhs, rhs, ndim,
+                                   /* precision_config */ nullptr,
+                                   /*algorithm*/ nullptr);
+    rewriter.replaceAllOpUsesWith(op, dg.getResult());
+
+    return success();
+  }
+};
+
 // currently limited to non-batched dot_general
 struct DotGeneralToSyrk
     : public CheckedOpRewritePattern<stablehlo::DotGeneralOp,
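
The prepareInputForDotGeneral lambda above replaces each broadcast_in_dim with a transpose of its un-broadcast operand: for operand dimension j, transposeDims[j] is bdims[j] minus the number of non-broadcast output dimensions that precede it, i.e. the rank of bdims[j] among the broadcast dimensions. A small Python mirror of that loop (illustrative only, not part of the commit):

def transpose_dims(bdims, out_rank):
    """Rank each operand dim by where its broadcast target lands
    among the broadcast dims, mirroring the C++ loop."""
    dims = [-1] * len(bdims)
    ncdims = 0  # non-broadcast output dims seen so far
    for i in range(out_rank):
        if i in bdims:
            dims[bdims.index(i)] = i - ncdims
        else:
            ncdims += 1
    return dims

# The two broadcasts from the lit test in this commit:
assert transpose_dims([0, 2], 3) == [0, 1]  # lhs: identity permutation
assert transpose_dims([1, 0], 3) == [1, 0]  # rhs: swap

With the broadcasts rewritten as transposes, both operands of the new dot_general contract over the reduce dimension: op.getDimensions() is passed as the contracting dimensions for both sides.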
@@ -26761,6 +26845,7 @@ struct EnzymeHLOOptPass
         ElementwiseWrap,
         ElementwiseExtend,
         SubtractMultiplyConstToAddMulConst,
+        ReduceMulBroadcastToDotGeneral,
         DotGeneralDistributiveSimplify<stablehlo::AddOp>,
         DotGeneralDistributiveSimplify<stablehlo::SubtractOp>,
         TrivialReduceWindowToReduceOp,

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 5 additions & 0 deletions

@@ -2505,6 +2505,11 @@ def EnzymeHLOUnroll : EnzymeHLOParameterizedPatternOp<
   }];
 }
 
+def ApplyReduceMulBroadcastToDotGeneralPatterns : EnzymeHLOPatternOp<
+  "reduce_mul_broadcast_to_dot_general"> {
+  let patterns = ["ReduceMulBroadcastToDotGeneral"];
+}
+
 def ApplyDotGeneralOnlyDiagonalAccessPatterns : EnzymeHLOPatternOp<
   "dot_general_only_diagonal_access"> {
   let patterns = ["DotGeneralOnlyDiagonalAccess"];
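
The string given to EnzymeHLOPatternOp ("reduce_mul_broadcast_to_dot_general") is the name under which the transform dialect exposes the pattern; the primitives.py hunk below adds exactly that name to the default optimization pass list, so the rewrite runs as part of the standard pipeline.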

src/enzyme_ad/jax/primitives.py

Lines changed: 1 addition & 0 deletions

@@ -304,6 +304,7 @@ def optimization_passes(
         "self_mul_to_convolution_like(0)",
         "trivial_reduce_window_to_reduce_op",
         "case_to_if",
+        "reduce_mul_broadcast_to_dot_general",
         "dot_general_add_distributive_simplify",
         "dot_general_subtract_distributive_simplify",
         "remove_no_ops_from_while_loop",
New lit test (file path not shown in this view)

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s
+
+module {
+  func.func @main(%arg0: tensor<100x100xf64>, %arg1: tensor<100x100xf64>) -> tensor<100x100xf64> {
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
+    %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 2] : (tensor<100x100xf64>) -> tensor<100x100x100xf64>
+    %1 = stablehlo.broadcast_in_dim %arg1, dims = [1, 0] : (tensor<100x100xf64>) -> tensor<100x100x100xf64>
+    %2 = stablehlo.multiply %0, %1 {enzymexla.symmetric_matrix = [#enzymexla<guaranteed NOTGUARANTEED>]} : tensor<100x100x100xf64>
+    %3 = stablehlo.reduce(%2 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<100x100x100xf64>, tensor<f64>) -> tensor<100x100xf64>
+    return %3 : tensor<100x100xf64>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<100x100xf64>, %arg1: tensor<100x100xf64>) -> tensor<100x100xf64> {
+// CHECK-NEXT:    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [0] x [1] : (tensor<100x100xf64>, tensor<100x100xf64>) -> tensor<100x100xf64>
+// CHECK-NEXT:    return %0 : tensor<100x100xf64>
+// CHECK-NEXT: }
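
The lit test exercises the rewrite end to end: under --enzyme-hlo-opt the broadcast/multiply/reduce chain collapses into a single stablehlo.dot_general, so the 100x100x100 intermediate and its elementwise work are replaced by one matrix contraction. Note that the CHECK lines show no transposes even though the pattern creates them: the lhs transpose here is the identity permutation, and the rhs [1, 0] transpose is evidently absorbed into contracting_dims = [0] x [1] by other patterns in the pipeline.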
