@@ -25877,6 +25877,141 @@ struct BinaryNegatedOperandsSimplify
   }
 };
 
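+// Rewrites a sum-reduction over an elementwise product of two broadcasts into
+// a stablehlo.dot_general that contracts over the reduced dimension. A sketch
+// of the intended rewrite (shapes are illustrative, not required):
+//
+//   %l = stablehlo.broadcast_in_dim %a, dims = [0, 1]
+//          : (tensor<2x3xf32>) -> tensor<2x3x4xf32>
+//   %r = stablehlo.broadcast_in_dim %b, dims = [1, 2]
+//          : (tensor<3x4xf32>) -> tensor<2x3x4xf32>
+//   %m = stablehlo.multiply %l, %r : tensor<2x3x4xf32>
+//   %s = sum-reduce(%m) over dims = [1] : ... -> tensor<2x4xf32>
+// =>
+//   %s = stablehlo.dot_general %a, %b, contracting_dims = [1] x [0]
+//          : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32>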
+struct ReduceMulBroadcastToDotGeneral
+    : public CheckedOpRewritePattern<stablehlo::ReduceOp,
+                                     ReduceMulBroadcastToDotGeneral> {
+  using CheckedOpRewritePattern<
+      stablehlo::ReduceOp,
+      ReduceMulBroadcastToDotGeneral>::CheckedOpRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(stablehlo::ReduceOp op,
+                                    PatternRewriter &rewriter) const {
+    if (op.getInputs().size() != 1 || op.getInitValues().size() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "only single-operand single-init reduce is supported");
+    }
+
+    auto dims = op.getDimensions();
+
+    Value input = op.getInputs()[0];
+    auto OT = cast<TensorType>(op.getResultTypes()[0]);
+
+    if (OT.getRank() != 2 || dims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "expected a rank-2 result and a single reduction dimension");
+
+    auto checkCommonReduce = mlir::stablehlo::CheckCommonReduceOp(op);
+    if (!checkCommonReduce.isAddReduce ||
+        !matchPattern(op.getInitValues()[0], m_AnyZeroFloat()))
+      return rewriter.notifyMatchFailure(
+          op, "reduction is not an add with a zero init");
+
+    auto mul = input.getDefiningOp<stablehlo::MulOp>();
+    if (!mul)
+      return rewriter.notifyMatchFailure(op, "input source is not a mul op");
+
+    Value mulLhs = mul.getLhs(), mulRhs = mul.getRhs();
+    auto lhsBdim = mulLhs.getDefiningOp<stablehlo::BroadcastInDimOp>(),
+         rhsBdim = mulRhs.getDefiningOp<stablehlo::BroadcastInDimOp>();
+
+    if (!lhsBdim || !rhsBdim)
+      return rewriter.notifyMatchFailure(op, "mul operands are not broadcasts");
+
+    // The reduced dimension must carry operand data on both sides, i.e. it
+    // must be a broadcast dimension of both broadcasts, for the summed
+    // product to be a contraction.
+    if (!llvm::is_contained(lhsBdim.getBroadcastDimensions(), dims[0]) ||
+        !llvm::is_contained(rhsBdim.getBroadcastDimensions(), dims[0]))
+      return rewriter.notifyMatchFailure(
+          op, "reduce dimension is not a broadcast dimension of both operands");
+
+    // The non-reduced broadcast dims of the LHS, followed by those of the
+    // RHS, must map onto the reduce result dims in order; otherwise the
+    // dot_general result would be permuted relative to the reduce.
+    SmallVector<int64_t> freeDims;
+    for (auto bdim : {lhsBdim, rhsBdim}) {
+      auto bds = llvm::to_vector(bdim.getBroadcastDimensions());
+      llvm::sort(bds);
+      for (int64_t d : bds)
+        if (d != dims[0])
+          freeDims.push_back(d > dims[0] ? d - 1 : d);
+    }
+    if (!llvm::equal(freeDims, llvm::seq<int64_t>(0, OT.getRank())))
+      return rewriter.notifyMatchFailure(
+          op, "broadcast dims do not line up with the reduce result");
+
+    auto prepareInputForDotGeneral =
+        [&](stablehlo::BroadcastInDimOp bdim) -> Value {
+      // Transpose the broadcast operand so its dims are ordered by the result
+      // dim each one maps to, e.g. broadcast dims [0, 2] give the permutation
+      // [0, 1] and broadcast dims [1, 0] give [1, 0].
+      auto bcastTy = cast<TensorType>(bdim.getResult().getType());
+
+      auto bdims = bdim.getBroadcastDimensions();
+      SmallVector<int64_t> transposeDims(bdims.size(), -1);
+
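+      // ncdims counts result dims created by the broadcast (absent from the
+      // operand) seen so far; subtracting it turns a result dim index into a
+      // position in the compacted, operand-rank space.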
+      int64_t ncdims = 0;
+      for (int64_t i = 0; i < bcastTy.getRank(); i++) {
+        bool inBDims = false;
+        for (auto [j, dim] : llvm::enumerate(bdims)) {
+          if (dim == i) {
+            inBDims = true;
+            transposeDims[j] = i - ncdims;
+            break;
+          }
+        }
+        if (!inBDims)
+          ncdims++;
+      }
+
+      Value prepared = stablehlo::TransposeOp::create(
+          rewriter, bdim.getLoc(), bdim.getOperand(), transposeDims);
+
+      return prepared;
+    };
+
+    auto lhs = prepareInputForDotGeneral(lhsBdim);
+    auto rhs = prepareInputForDotGeneral(rhsBdim);
+
+    // Contract over the reduced dimension. Its index within each prepared
+    // (transposed, broadcast-free) operand is the number of that operand's
+    // broadcast dims preceding it, not the reduce dim index itself.
+    auto contractingDim = [&](stablehlo::BroadcastInDimOp bdim) -> int64_t {
+      return llvm::count_if(bdim.getBroadcastDimensions(),
+                            [&](int64_t d) { return d < dims[0]; });
+    };
+    auto ndim = stablehlo::DotDimensionNumbersAttr::get(
+        op.getContext(), /*lhsBatchingDimensions=*/{},
+        /*rhsBatchingDimensions=*/{}, {contractingDim(lhsBdim)},
+        {contractingDim(rhsBdim)});
+
+    auto dg = stablehlo::DotGeneralOp::create(
+        rewriter, op.getLoc(), OT, lhs, rhs, ndim,
+        /*precision_config=*/nullptr, /*algorithm=*/nullptr);
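+    // The reduce op is left without uses here; the rewrite driver's DCE
+    // removes it along with the now-dead mul/broadcast chain.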
+    rewriter.replaceAllOpUsesWith(op, dg.getResult());
+
+    return success();
+  }
+};
+
 // currently limited to non-batched dot_general
 struct DotGeneralToSyrk
     : public CheckedOpRewritePattern<stablehlo::DotGeneralOp,
@@ -26761,6 +26896,7 @@ struct EnzymeHLOOptPass
         ElementwiseWrap,
         ElementwiseExtend,
         SubtractMultiplyConstToAddMulConst,
+        ReduceMulBroadcastToDotGeneral,
         DotGeneralDistributiveSimplify<stablehlo::AddOp>,
         DotGeneralDistributiveSimplify<stablehlo::SubtractOp>,
         TrivialReduceWindowToReduceOp,