Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Passes/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms
EnzymeBatchToTensorPass.cpp
EnzymeWrapPass.cpp
InlineEnzymeRegions.cpp
HoistEnzymeRegions.cpp
LowerLLVMExtPass.cpp
PrintActivityAnalysis.cpp
PrintAliasAnalysis.cpp
Expand Down
156 changes: 156 additions & 0 deletions enzyme/Enzyme/MLIR/Passes/HoistEnzymeRegions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
//===- HoistEnzymeRegions.cpp - Invariant code motion ------------===//
//===- within enzyme.autodiff_region ----------=== //
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
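// This file implements a pass that hoists operations which do not depend on
// the block arguments of an enzyme.autodiff_region out of that region, placing
// them immediately before the autodiff_region op.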
//
//===----------------------------------------------------------------------===//

#include "Dialect/Ops.h"
#include "Interfaces/AutoDiffOpInterface.h"
#include "Interfaces/Utils.h"
#include "Passes/Passes.h"

#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;

namespace mlir {
namespace enzyme {
#define GEN_PASS_DEF_HOISTENZYMEFROMREGIONPASS
#include "Passes/Passes.h.inc"
} // namespace enzyme
} // namespace mlir

// Debug category name for this file; it is also embedded in the ENZYME_DBGS
// prefix below.
#define DEBUG_TYPE "enzyme-hoist"
// Expands to the LLVM debug stream with a "[enzyme-hoist]" prefix already
// written to it. NOTE(review): not referenced in the visible part of this
// file; presumably kept for ad-hoc debugging.
#define ENZYME_DBGS llvm::dbgs() << "[" << DEBUG_TYPE << "]"

namespace {

/// Returns true when every value in `values` would still be available if its
/// user were moved in front of `rootOp`.
///
/// A value qualifies when it either
///   * already properly dominates `rootOp`, or
///   * is produced by an operation recorded in `specialOps` (the operations
///     the caller has already selected for hoisting).
///
/// Any other block argument disqualifies the whole range.
static bool checkRangeDominance(enzyme::AutoDiffRegionOp &rootOp,
                                SetVector<Operation *> &specialOps,
                                ValueRange values) {
  DominanceInfo domInfo(rootOp);

  auto availableBeforeRoot = [&](Value candidate) -> bool {
    if (domInfo.properlyDominates(candidate, rootOp))
      return true;

    // Block arguments within autodiff_region are not supported
    // TODO: add support for enzyme_const block arguments
    if (isa<BlockArgument>(candidate))
      return false;

    // Otherwise the producer must itself be one of the selected operations.
    return llvm::is_contained(specialOps, candidate.getDefiningOp());
  };

  return llvm::all_of(values, availableBeforeRoot);
}

/// Rewrite pattern that moves operations out of an `enzyme.autodiff_region`
/// when they do not depend on the region's block arguments.
///
/// The rewrite proceeds in three steps:
///   1. Uses (inside the region) of values that are also passed as primal
///      inputs are redirected to the matching region block argument.
///   2. Every non-terminator op directly inside the region is classified as
///      "lift" or "stationary". An op is lifted only if
///        - its memory effects could be collected,
///        - all of its operands (and, for ops with regions, all values its
///          regions capture from above) are available before `rootOp`, and
///        - none of its effects conflicts with an effect of an op that stays
///          behind and precedes it.
///   3. The lifted ops are moved, in their original order, in front of
///      `rootOp`.
///
/// The pattern reports failure when it did not modify the IR so that the
/// greedy driver can converge.
struct HoistEnzymeAutoDiff : public OpRewritePattern<enzyme::AutoDiffRegionOp> {
  using OpRewritePattern<enzyme::AutoDiffRegionOp>::OpRewritePattern;

  /// Returns true for a write/read, read/write or write/write pair, i.e. a
  /// pair of effects whose relative order matters if they touch the same
  /// memory.
  static bool isOrderSensitive(MemoryEffects::EffectInstance &lhs,
                               MemoryEffects::EffectInstance &rhs) {
    bool lhsReads = isa<MemoryEffects::Read>(lhs.getEffect());
    bool lhsWrites = isa<MemoryEffects::Write>(lhs.getEffect());
    bool rhsReads = isa<MemoryEffects::Read>(rhs.getEffect());
    bool rhsWrites = isa<MemoryEffects::Write>(rhs.getEffect());
    return (lhsWrites && (rhsReads || rhsWrites)) || (lhsReads && rhsWrites);
  }

  /// Returns true when some effect of the candidate op is order sensitive
  /// with respect to, and may alias, an effect of an op that stays in the
  /// region.
  static bool conflictsWithStationary(
      llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &candidateEffects,
      llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &stationaryEffects) {
    for (MemoryEffects::EffectInstance &stationaryEffect : stationaryEffects) {
      for (MemoryEffects::EffectInstance &candidateEffect : candidateEffects) {
        if (isOrderSensitive(stationaryEffect, candidateEffect) &&
            enzyme::oputils::mayAlias(candidateEffect, stationaryEffect))
          return true;
      }
    }
    return false;
  }

  LogicalResult matchAndRewrite(enzyme::AutoDiffRegionOp rootOp,
                                PatternRewriter &rewriter) const override {
    Region &autodiffRegion = rootOp.getBody();
    SmallVector<Value> primalArgs = rootOp.getPrimalInputs();
    SmallVector<Value> regionPrimalArgs(autodiffRegion.getArguments());

    if (primalArgs.size() != regionPrimalArgs.size())
      return failure();

    // Tracks whether the IR was modified; the pattern must not report success
    // otherwise.
    bool changed = false;

    // Step 1: inside the region, refer to primal inputs through the region's
    // block arguments instead of capturing the outer value directly.
    llvm::SetVector<Value> freeValues;
    getUsedValuesDefinedAbove(autodiffRegion, freeValues);

    for (Value value : freeValues) {
      for (auto [pval, bval] : llvm::zip(primalArgs, regionPrimalArgs)) {
        if (value != pval)
          continue;
        // Copy out of the structured binding so it can be captured below.
        Value blockArg = bval;
        for (OpOperand &use : llvm::make_early_inc_range(value.getUses())) {
          Operation *user = use.getOwner();
          if (!rootOp->isProperAncestor(user))
            continue;
          // Go through the rewriter so the driver is notified of the change.
          rewriter.modifyOpInPlace(user, [&] { use.assign(blockArg); });
          changed = true;
        }
      }
    }

    // Step 2: classify the ops of the region.
    llvm::SetVector<Operation *> liftOps;
    llvm::SmallVector<MemoryEffects::EffectInstance> stationaryEffects;
    // Set once an op whose effects could not be collected stays in the region.
    // Its effects are unknown, so no later op that touches memory may be
    // moved across it.
    bool hasOpaqueStationaryOp = false;

    for (Block &blk : autodiffRegion.getBlocks()) {
      for (Operation &bodyOp : blk.without_terminator()) {
        llvm::SmallVector<MemoryEffects::EffectInstance> bodyOpEffects;
        bool couldCollectEffects =
            enzyme::oputils::collectOpEffects(&bodyOp, bodyOpEffects);

        // Each condition can only veto lifting; none may re-enable it.
        bool canLift = couldCollectEffects;

        if (canLift && hasOpaqueStationaryOp && !bodyOpEffects.empty())
          canLift = false;

        if (canLift)
          canLift = checkRangeDominance(rootOp, liftOps, bodyOp.getOperands());

        // Ops with regions additionally must not capture values that stay
        // inside the autodiff region.
        if (canLift && bodyOp.getNumRegions()) {
          llvm::SetVector<Value> values;
          getUsedValuesDefinedAbove(bodyOp.getRegions(), values);
          canLift = checkRangeDominance(rootOp, liftOps, values.getArrayRef());
        }

        // Check for memory conflicts with current set of stationary ops
        if (canLift && conflictsWithStationary(bodyOpEffects, stationaryEffects))
          canLift = false;

        if (canLift) {
          liftOps.insert(&bodyOp);
        } else {
          if (!couldCollectEffects)
            hasOpaqueStationaryOp = true;
          stationaryEffects.append(bodyOpEffects.begin(), bodyOpEffects.end());
        }
      }
    }

    if (liftOps.empty() && !changed)
      return failure();

    // Step 3: lift operations, preserving their relative order.
    for (Operation *op : liftOps)
      rewriter.moveOpBefore(op, rootOp);

    return success();
  }
};

/// Pass driver: applies `HoistEnzymeAutoDiff` greedily to the operation the
/// pass is run on.
struct HoistEnzymeFromRegion
    : public enzyme::impl::HoistEnzymeFromRegionPassBase<
          HoistEnzymeFromRegion> {
  void runOnOperation() override {
    MLIRContext *context = &getContext();

    RewritePatternSet hoistPatterns(context);
    hoistPatterns.add<HoistEnzymeAutoDiff>(context);

    // The driver's result is deliberately discarded, so the pass never
    // signals failure based on it.
    (void)applyPatternsGreedily(getOperation(), std::move(hoistPatterns),
                                GreedyRewriteConfig());
  }
};
} // namespace
7 changes: 7 additions & 0 deletions enzyme/Enzyme/MLIR/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,13 @@ def OutlineEnzymeFromRegionPass : Pass<"outline-enzyme-regions"> {
];
}

// Declares the `hoist-enzyme-regions` pass. The generated
// HoistEnzymeFromRegionPassBase class is implemented in
// Passes/HoistEnzymeRegions.cpp.
def HoistEnzymeFromRegionPass : Pass<"hoist-enzyme-regions"> {
let summary = "Hoist independent primal ops out of autodiff_region.";
let dependentDialects = [
"enzyme::EnzymeDialect",
];
}

def EnzymeOpsToMemRefPass : Pass<"convert-enzyme-to-memref"> {
let summary = "Lower custom Enzyme ops to the MemRef dialect";
let dependentDialects = [
Expand Down
1 change: 1 addition & 0 deletions enzyme/test/MLIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_subdirectory(ActivityAnalysis)
add_subdirectory(AliasAnalysis)
add_subdirectory(Batch)
add_subdirectory(ForwardMode)
add_subdirectory(OptimizeAD)
add_subdirectory(Passes)
add_subdirectory(ProbProg)
add_subdirectory(ReverseMode)
Expand Down
8 changes: 8 additions & 0 deletions enzyme/test/MLIR/OptimizeAD/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Run regression and unit tests
# Defines the `check-enzymemlir-optimizead` target, which runs the lit tests
# found under this directory's binary dir. The target depends on
# `enzymemlir-opt`; `-v` is passed through to lit as an extra argument.
add_lit_testsuite(check-enzymemlir-optimizead "Running MLIR OptimizeAD tests"
${CMAKE_CURRENT_BINARY_DIR}
DEPENDS enzymemlir-opt
ARGS -v
)

# Place the target in the "Tests" folder.
set_target_properties(check-enzymemlir-optimizead PROPERTIES FOLDER "Tests")
73 changes: 73 additions & 0 deletions enzyme/test/MLIR/OptimizeAD/hoist_region.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// RUN: enzymemlir-opt --split-input-file --hoist-enzyme-regions %s | FileCheck %s
// CHECK-LABEL: func.func @foo
// CHECK-SAME: (%arg0: f64, %arg1: f64, %arg2: f64) -> f64
// CHECK: %c10 = arith.constant 10 : index
// CHECK: %c1 = arith.constant 1 : index
// CHECK: %cst = arith.constant 2.500000e+00 : f64
// CHECK: %cst_0 = arith.constant 2.000000e+00 : f64
// CHECK: %cst_1 = arith.constant 0.000000e+00 : f64
// CHECK: %cst_2 = arith.constant 1.000000e+02 : f64
// CHECK: %0 = arith.mulf %arg2, %cst_0 : f64
// CHECK: %1 = scf.for %{{.*}} = %c1 to %c10 step %c1 iter_args(%{{.*}} = %cst) -> (f64) {
// CHECK: %{{.*}} = arith.mulf %{{.*}}, %cst_2 : f64
// CHECK: %{{.*}} = scf.for %{{.*}} = %c1 to %c10 step %c1 iter_args(%{{.*}} = %{{.*}}) -> (f64) {
// CHECK: %{{.*}} = arith.addf %{{.*}}, %0 : f64
// CHECK: scf.yield %{{.*}} : f64
// CHECK: }
// CHECK: scf.yield %{{.*}} : f64
// CHECK: }
// CHECK: %2 = enzyme.autodiff_region(%arg0, %arg1) {
// CHECK: ^bb0(%{{.*}}: f64):
// CHECK: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64
// CHECK: %{{.*}} = arith.mulf %{{.*}}, %0 : f64
// CHECK: %{{.*}} = arith.mulf %{{.*}}, %1 : f64
// CHECK: %{{.*}} = arith.addf %{{.*}}, %cst_1 : f64
// CHECK: %{{.*}} = scf.for %{{.*}} = %c1 to %c10 step %c1 iter_args(%{{.*}} = %{{.*}}) -> (f64) {
// CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64
// CHECK: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64
// CHECK: scf.yield %{{.*}} : f64
// CHECK: }
// CHECK: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64
// CHECK: enzyme.yield %{{.*}} : f64
// CHECK: } attributes {{.*}} : (f64, f64) -> f64

// Test input: %xx and %yy_cst are defined outside the autodiff region, %arg2
// is the region's block argument. Ops that do not (transitively) depend on
// %arg2 are expected to end up in front of enzyme.autodiff_region.
func.func @foo(%arg0: f64, %arg1: f64,%xx: f64) -> f64 {

%yy_cst = arith.constant 100.0 : f64
%0 = enzyme.autodiff_region(%arg0, %arg1) {
^bb0(%arg2: f64):
// hoistable constant ops (%c1 has no users in this function)
%c0 = arith.constant 0.0 : f64
%c1 = arith.constant 1.0 : f64
%c2 = arith.constant 2.0 : f64
// depends only on a constant and the outer value %xx, so it is hoistable too
%cx = arith.mulf %c2, %xx : f64

// these use the block argument %arg2 and therefore must stay in the region
%sq = arith.mulf %arg2, %arg2 : f64
%sqx = arith.mulf %sq, %cx : f64

// hoistable loops: the nest only reads constants, %yy_cst and %cx
%yy0 = arith.constant 2.5 : f64
%one = arith.constant 1 : index
%ten = arith.constant 10 : index
%yy = scf.for %iv = %one to %ten step %one iter_args(%yy_iter = %yy0) -> (f64) {
%tm = arith.mulf %yy_iter, %yy_cst : f64
// note: the inner body reads %tm, not its own iter_arg %tm_iter
%ta = scf.for %jv = %one to %ten step %one iter_args(%tm_iter = %tm) -> (f64) {
%ta = arith.addf %tm, %cx : f64
scf.yield %ta : f64
}
scf.yield %ta : f64
}

// everything below depends on %sqx (and thus on %arg2) and stays in place
%sqxy = arith.mulf %sqx, %yy : f64
%zz0 = arith.addf %sqx, %c0 : f64
%zz = scf.for %iv = %one to %ten step %one iter_args(%zz_iter = %zz0) ->(f64) {
%zm = arith.addf %zz_iter, %sqx : f64
%zout = arith.mulf %zm, %zz_iter : f64
scf.yield %zout : f64
}

%sqxyz = arith.mulf %zz, %sqxy : f64
enzyme.yield %sqxyz : f64
} attributes {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>]} : (f64, f64) -> f64
return %0 : f64
}
Loading