From 99bd84317d5b961980430f1437645dec395d5f1f Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Wed, 15 Oct 2025 13:59:22 +0200 Subject: [PATCH 01/26] split mode --- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 31 +++++++++++++++++++ enzyme/Enzyme/MLIR/Dialect/Ops.cpp | 30 +++++++++++++++++++ enzyme/deferred.mlir | 40 +++++++++++++++++++++++++ 3 files changed, 101 insertions(+) create mode 100644 enzyme/deferred.mlir diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index 84f72f490c9..bd2e3149776 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -206,6 +206,28 @@ def AutoDiffOp : Enzyme_Op<"autodiff", }]; } +def AutoDiffDeferredPrimalOp : Enzyme_Op<"autodiff_deferred_primal", + [DeclareOpInterfaceMethods]> { + let summary = "Runs an augmented primal that can later be used to generate reverse with enzyme.autodiff_deferred_reverse"; + let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity); + let results = (outs Variadic:$results, AnyType:$tape); + + let assemblyFormat = [{ + $fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + }]; +} + +def AutoDiffDeferredReverseOp : Enzyme_Op<"autodiff_deferred_reverse", + [DeclareOpInterfaceMethods]> { + let summary = "Runs the reverse from an enzyme.autodiff_deferred_primal result"; + let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity); + let results = (outs Variadic:$results); + + let assemblyFormat = [{ + $fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + }]; +} + def AutoDiffRegionOp : Enzyme_Op<"autodiff_region", [AutomaticAllocationScope]> { let summary = "Perform reverse mode AD on a child region"; let arguments = (ins Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr:$width, DefaultValuedAttr:$strong_zero, OptionalAttr:$fn); @@ -318,6 +340,15 @@ def Cache : Enzyme_Type<"Cache"> { let assemblyFormat = "`<` $type `>`"; } +def Tape : Enzyme_Type<"Tape"> { + let summary = "Tape for reverse deferred"; + let description = [{ + "Tape for reverse deferred" + }]; + let parameters = (ins); + let mnemonic = "Tape"; +} + def Gradient : Enzyme_Type<"Gradient"> { let summary = "Mutable storage for accumulating gradients"; let description = [{ diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index eaab5dbb05b..91cbfe07d8e 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -526,6 +526,36 @@ LogicalResult BatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } +//===----------------------------------------------------------------------===// +// AutoDiffDeferredOp +//===----------------------------------------------------------------------===// + +LogicalResult +AutoDiffDeferredPrimalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // TODO: Verify that the result type is same as the type of the referenced + // func.func op. + auto global = + symbolTable.lookupNearestSymbolFrom(*this, getFnAttr()); + if (!global) + return emitOpError("'") + << getFn() << "' does not reference a valid global funcOp"; + + return success(); +} + +LogicalResult +AutoDiffDeferredReverseOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // TODO: Verify that the result type is same as the type of the referenced + // func.func op. + auto global = + symbolTable.lookupNearestSymbolFrom(*this, getFnAttr()); + if (!global) + return emitOpError("'") + << getFn() << "' does not reference a valid global funcOp"; + + return success(); +} + //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// diff --git a/enzyme/deferred.mlir b/enzyme/deferred.mlir new file mode 100644 index 00000000000..b4776add700 --- /dev/null +++ b/enzyme/deferred.mlir @@ -0,0 +1,40 @@ +module { + + func.func @mul(%a: f32, %b: f32) -> f32 { + %0 = arith.mulf %a, %b : f32 + return %0 : f32 + } + + func.func @main() { + %a = arith.constant 1.0 : f32 + %b = arith.constant 1.0 : f32 + + %r, %tape = enzyme.autodiff_deferred_primal @mul(%a, %b) { + activity=[#enzyme, #enzyme], + ret_activity=[#enzyme] + } : (f32, f32) -> (f32, !enzyme.Tape) + + // --- + + %dres = arith.constant 1.0 : f32 + %da, %db = enzyme.autodiff_deferred_reverse @mul(%tape, %dres) { + activity=[#enzyme, #enzyme], + ret_activity=[#enzyme] + } : (!enzyme.Tape, f32) -> (f32, f32) + + return + } + + func.func @mul_primal(%a: f32, %b: f32) -> (f32, !enzyme.Tape) { + %cache0 = "enzyme.init"() : () -> !enzyme.Cache + %cache1 = "enzyme.init"() : () -> !enzyme.Cache + + %0 = arith.mulf %a, %b : f32 + "enzyme.push"(%cache0, %a) : (!enzyme.Cache, f32) -> () + "enzyme.push"(%cache1, %a) : (!enzyme.Cache, f32) -> () + + %tape = "enzyme_.new_tape"(%cache0, %cache1) : (!enzyme.Cache, !enzyme.Cache) -> !enzyme.Tape + return %0, %tape : f32, !enzyme.Tape + } + +} From f6a3ddcb1848f535b0b848f9e7eebad730d3e776 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Wed, 29 Oct 2025 11:46:16 +0100 Subject: [PATCH 02/26] wip: Split mode and custom reverse rules --- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 62 ++++++- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h | 10 + .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 162 ++++++++++++++++ enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 173 +++++++++++++++++- 4 files changed, 403 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index bd2e3149776..bd906acb811 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -209,7 +209,7 @@ def AutoDiffOp : Enzyme_Op<"autodiff", def AutoDiffDeferredPrimalOp : Enzyme_Op<"autodiff_deferred_primal", [DeclareOpInterfaceMethods]> { let summary = "Runs an augmented primal that can later be used to generate reverse with enzyme.autodiff_deferred_reverse"; - let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity); + let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr:$width, DefaultValuedAttr:$strong_zero); let results = (outs Variadic:$results, AnyType:$tape); let assemblyFormat = [{ @@ -228,6 +228,62 @@ def AutoDiffDeferredReverseOp : Enzyme_Op<"autodiff_deferred_reverse", }]; } +def CustomReverseRuleOp : Enzyme_Op<"custom_reverse_rule", [IsolatedFromAbove]> { + let summary = "Parent operation for custom reverse rule declaration."; + let arguments = (ins FlatSymbolRefAttr:$name, TypeAttrOf:$function_type, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity); + let regions = (region AnyRegion:$body); + let results = (outs); + + let assemblyFormat = [{ + $name $body attr-dict-with-keyword + }]; +} + +def AugmentedPrimalOp : Enzyme_Op<"augmented_primal", [HasParent<"CustomReverseRuleOp">]> { + let summary = "Defines the augmented primal for a custom reverse rule"; + let arguments = (ins); + let regions = (region AnyRegion:$body); + let results = (outs); + + let assemblyFormat = [{ + $body attr-dict-with-keyword + }]; +} + +def ReverseOp : Enzyme_Op<"reverse", [HasParent<"CustomReverseRuleOp">]> { + let summary = "Defined the reverse for a custom rule."; + let arguments = (ins); + let regions = (region AnyRegion:$body); + let results = (outs); + let assemblyFormat = [{ + $body attr-dict-with-keyword + }]; +} + +def CallAugmentedPrimalOp : Enzyme_Op<"call_augmented_primal", + // [DeclareOpInterfaceMethods]> { + []> { + let summary = ""; + let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs); + let results = (outs Variadic:$outputs); + + let assemblyFormat = [{ + $fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + }]; +} + +def CallCustomReverseOp : Enzyme_Op<"call_custom_reverse", + // [DeclareOpInterfaceMethods]> { + []> { + let summary = ""; + let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs); + let results = (outs Variadic:$outputs); + + let assemblyFormat = [{ + $fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + }]; +} + def AutoDiffRegionOp : Enzyme_Op<"autodiff_region", [AutomaticAllocationScope]> { let summary = "Perform reverse mode AD on a child region"; let arguments = (ins Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr:$width, DefaultValuedAttr:$strong_zero, OptionalAttr:$fn); @@ -266,8 +322,8 @@ def AutoDiffRegionOp : Enzyme_Op<"autodiff_region", [AutomaticAllocationScope]> } def YieldOp : Enzyme_Op<"yield", [Pure, ReturnLike, Terminator, - ParentOneOf<["AutoDiffRegionOp", "LoopOp"]>]> { - let summary = "Yield values at the end of an autodiff_region or loop op"; + ParentOneOf<["AutoDiffRegionOp", "LoopOp", "ReverseOp", "AugmentedPrimalOp", "CustomReverseRuleOp"]>]> { + let summary = "Yield values at the end of an autodiff_region, loop op, reverse op, aug primal op or custom reverse rule op"; let arguments = (ins Variadic:$operands); let assemblyFormat = [{ attr-dict ($operands^ `:` type($operands))? diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h index af4867cdd71..76d12b794ab 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h @@ -234,6 +234,16 @@ class MEnzymeLogic { void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero); + FlatSymbolRefAttr + CreateSplitModeDiff(FunctionOpInterface fn, std::vector retType, + std::vector constants, MTypeAnalysis &TA, + std::vector returnPrimals, + std::vector returnShadows, DerivativeMode mode, + bool freeMemory, size_t width, mlir::Type addedType, + MFnTypeInfo type_args, std::vector volatile_args, + void *augmented, bool omp, llvm::StringRef postpasses, + bool verifyPostPasses, bool strongZero); + void initializeShadowValues(SmallVector &dominatorToposortBlocks, MGradientUtilsReverse *gutils); diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index caf681a084e..bf747c81052 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -1,3 +1,4 @@ +#include "CloneFunction.h" #include "Dialect/Ops.h" #include "Interfaces/AutoDiffOpInterface.h" #include "Interfaces/AutoDiffTypeInterface.h" @@ -305,3 +306,164 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( return nf; } + +FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( + FunctionOpInterface fn, std::vector retType, + std::vector constants, MTypeAnalysis &TA, + std::vector returnPrimals, std::vector returnShadows, + DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, + MFnTypeInfo type_args, std::vector volatile_args, void *augmented, + bool omp, llvm::StringRef postpasses, bool verifyPostPasses, + bool strongZero) { + + SymbolTable symbolTable(SymbolTable::getNearestSymbolTable(fn)); + + IRMapping originalToNew; + std::map originalToNewOps; + + SmallPtrSet returnvals; + SmallPtrSet constant_values; + SmallPtrSet nonconstant_values; + for (auto &&[arg, act] : + llvm::zip(fn.getFunctionBody().getArguments(), constants)) { + if (act == DIFFE_TYPE::CONSTANT) + constant_values.insert(arg); + else + nonconstant_values.insert(arg); + } + IRMapping invertedPointers; + + SmallVector returnPrimalsP(returnPrimals.begin(), returnPrimals.end()); + SmallVector returnShadowsP(returnShadows.begin(), returnShadows.end()); + + auto name = fn.getName(); + + SmallVector argActivityAttrs; + for (auto act : constants) + argActivityAttrs.push_back(mlir::enzyme::ActivityAttr::get( + fn.getContext(), mlir::enzyme::Activity::enzyme_active)); + + SmallVector retActivityAttrs; + for (auto act : constants) + retActivityAttrs.push_back(mlir::enzyme::ActivityAttr::get( + fn.getContext(), mlir::enzyme::Activity::enzyme_active)); + + auto argActivityAttr = ArrayAttr::get(fn.getContext(), argActivityAttrs); + auto retActivityAttr = ArrayAttr::get(fn.getContext(), retActivityAttrs); + + auto customRuleName = name + "_reverse_rule"; + SmallVector nameBuf; + + auto ruleNameAttr = FlatSymbolRefAttr::get( + fn.getContext(), customRuleName.toStringRef(nameBuf)); + + OpBuilder builder(fn); + auto customRule = enzyme::CustomReverseRuleOp::create( + builder, fn.getLoc(), ruleNameAttr, TypeAttr::get(fn.getFunctionType()), + argActivityAttr, retActivityAttr); + + Block *ruleBody = new Block(); + customRule.getBody().push_back(ruleBody); + + OpBuilder ruleBuilder(ruleBody, ruleBody->begin()); + + auto reverse = ruleBuilder.create(fn.getLoc()); + ruleBuilder.create(fn.getLoc(), ValueRange{}); + + ruleBuilder.setInsertionPoint(reverse); + + // FunctionOpInterface newF = + // cast(fn->cloneWithoutRegions()); + // SymbolTable::setSymbolName( + // newF, StringAttr::get(fn->getContext(), name + "_primal")); + + // FunctionOpInterface newFRev = + // cast(fn->cloneWithoutRegions()); + // SymbolTable::setSymbolName( + // newFRev, StringAttr::get(fn->getContext(), name + "_reverse")); + // + // cloneInto(&fn.getFunctionBody(), &newF.getFunctionBody(), originalToNew, + // originalToNewOps); + // + // llvm::errs() << "fn = " << newF << "\n"; + + auto newFunc = cast(fn->cloneWithoutRegions()); + cloneInto(&fn.getFunctionBody(), &newFunc.getFunctionBody(), originalToNew, + originalToNewOps); + + llvm::errs() << "new func = " << newFunc << "\n"; + + MGradientUtilsReverse *gutils = new MGradientUtilsReverse( + *this, newFunc, fn, TA, invertedPointers, returnPrimalsP, returnShadowsP, + constant_values, nonconstant_values, retType, constants, originalToNew, + originalToNewOps, mode, width, omp, postpasses, verifyPostPasses, + strongZero); + + gutils->createReverseModeBlocks(fn.getFunctionBody(), reverse.getBody()); + gutils->registerCacheCreatorHook([&](Type ty) -> std::pair { + Value cache = ruleBuilder.create(fn.getLoc(), ty); + return {cache, cache}; + }); + gutils->registerGradientCreatorHook([&](Location loc, Type ty) -> Value { + auto reverseEntry = &reverse.getBody().front(); + OpBuilder gBuilder(reverseEntry, reverseEntry->begin()); + return gBuilder.create(loc, ty); + }); + + bool valid = true; + for (auto &oBB : fn.getFunctionBody()) { + Block *newBB = gutils->getNewFromOriginal(&oBB); + Block *reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(&oBB); + if (oBB.getNumSuccessors() == 0) { + Operation *oTerm = oBB.getTerminator(); + for (auto [res, act] : llvm::zip_equal(oTerm->getOperands(), retType)) { + if (act == DIFFE_TYPE::OUT_DIFF) { + OpBuilder diffeBuilder(reverseBB, reverseBB->begin()); + auto diffe = reverseBB->addArgument(res.getType(), res.getLoc()); + gutils->setDiffe(res, diffe, diffeBuilder); + } + } + } + + OpBuilder revBuilder(reverseBB, reverseBB->end()); + + auto first = oBB.rbegin(); + first++; + auto last = oBB.rend(); + for (auto it = first; it != last; ++it) { + Operation *op = &*it; + valid &= visitChild(op, revBuilder, gutils).succeeded(); + } + + if (oBB.isEntryBlock()) { + SmallVector toYield; + OpBuilder rBuilder(reverseBB, reverseBB->end()); + for (auto [act, arg] : llvm::zip_equal( + constants, fn.getFunctionBody().front().getArguments())) { + if (act == DIFFE_TYPE::OUT_DIFF) { + toYield.push_back(gutils->diffe(arg, rBuilder)); + } + } + rBuilder.create(fn.getLoc(), toYield); + } + } + + ruleBuilder.setInsertionPoint(reverse); + auto augmentedPrimal = + ruleBuilder.create(fn.getLoc()); + augmentedPrimal.getBody().takeBody(newFunc.getFunctionBody()); + for (Block &b : augmentedPrimal.getBody()) { + if (b.getNumSuccessors() == 0) { + Operation *term = b.getTerminator(); + OpBuilder builder(term); + builder.create(term->getLoc(), term->getOperands()); + term->erase(); + } + } + + delete gutils; + + newFunc->erase(); + + return ruleNameAttr; +} diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index ef3d7473a09..2fcd252f5f2 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -221,7 +221,7 @@ struct DifferentiatePass for (auto act : CI.getActivity()) { if (call_idx >= CI.getInputs().size()) { llvm::errs() << "Too few arguments to autodiff op" - << "CI: " << CI << "\n"; + << " CI: " << CI << "\n"; return failure(); } mlir::Value res = CI.getInputs()[call_idx]; @@ -350,6 +350,153 @@ struct DifferentiatePass return success(); } + LogicalResult HandleSplitModeAutoDiff(SymbolTableCollection &symbolTable, + enzyme::AutoDiffDeferredPrimalOp CI) { + auto tape = CI.getTape(); + + SmallVector reverseCalls; + for (auto user : tape.getUsers()) { + if (isa(user)) + reverseCalls.push_back(user); + else { + user->emitError() << "todo: unsupported tape usage"; + return failure(); + } + } + + auto &symbTable = + symbolTable.getSymbolTable(SymbolTable::getNearestSymbolTable(CI)); + + auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()); + auto fn = cast(symbolOp); + assert(fn); + if (CI.getActivity().size() != fn.getNumArguments()) { + llvm::errs() << "Incorrect number of argument activities on autodiff op" + << "CI: " << CI << ", expected " << fn.getNumArguments() + << " found " << CI.getActivity().size() << "\n"; + return failure(); + } + if (CI.getRetActivity().size() != fn.getNumResults()) { + llvm::errs() << "Incorrect number of result activities on autodiff op" + << "CI: " << CI << ", expected " << fn.getNumResults() + << " found " << CI.getRetActivity().size() << "\n"; + return failure(); + } + + std::vector arg_activities; + SmallVector args; + + size_t call_idx = 0; + { + for (auto act : CI.getActivity()) { + if (call_idx >= CI.getInputs().size()) { + llvm::errs() << "Too few arguments to autodiff op" + << " CI: " << CI << "\n"; + return failure(); + } + mlir::Value res = CI.getInputs()[call_idx]; + ++call_idx; + + auto iattr = cast(act); + auto val = iattr.getValue(); + DIFFE_TYPE ty; + switch (val) { + case mlir::enzyme::Activity::enzyme_active: + ty = DIFFE_TYPE::OUT_DIFF; + break; + case mlir::enzyme::Activity::enzyme_dup: + ty = DIFFE_TYPE::DUP_ARG; + break; + case mlir::enzyme::Activity::enzyme_const: + ty = DIFFE_TYPE::CONSTANT; + break; + case mlir::enzyme::Activity::enzyme_dupnoneed: + ty = DIFFE_TYPE::DUP_NONEED; + break; + case mlir::enzyme::Activity::enzyme_activenoneed: + ty = DIFFE_TYPE::OUT_DIFF; + assert(0 && "unsupported arg activenoneed"); + break; + case mlir::enzyme::Activity::enzyme_constnoneed: + ty = DIFFE_TYPE::CONSTANT; + assert(0 && "unsupported arg constnoneed"); + break; + } + arg_activities.push_back(ty); + args.push_back(res); + } + } + + bool omp = false; + auto mode = DerivativeMode::ReverseModeCombined; + std::vector retType; + std::vector returnPrimals; + std::vector returnShadows; + + // Add the return gradient + for (auto act : CI.getRetActivity()) { + auto iattr = cast(act); + auto val = iattr.getValue(); + DIFFE_TYPE ty; + bool primalNeeded = true; + switch (val) { + case mlir::enzyme::Activity::enzyme_active: + ty = DIFFE_TYPE::OUT_DIFF; + break; + case mlir::enzyme::Activity::enzyme_dup: + ty = DIFFE_TYPE::DUP_ARG; + break; + case mlir::enzyme::Activity::enzyme_const: + ty = DIFFE_TYPE::CONSTANT; + break; + case mlir::enzyme::Activity::enzyme_dupnoneed: + ty = DIFFE_TYPE::DUP_NONEED; + primalNeeded = false; + break; + case mlir::enzyme::Activity::enzyme_activenoneed: + ty = DIFFE_TYPE::OUT_DIFF; + primalNeeded = false; + break; + case mlir::enzyme::Activity::enzyme_constnoneed: + ty = DIFFE_TYPE::CONSTANT; + primalNeeded = false; + break; + } + retType.push_back(ty); + returnPrimals.push_back(primalNeeded); + returnShadows.push_back(false); + } + + std::vector volatile_args( + fn.getNumArguments(), !(mode == DerivativeMode::ReverseModeCombined)); + + MTypeAnalysis TA; + auto type_args = TA.getAnalyzedTypeInfo(fn); + bool freeMemory = true; + size_t width = CI.getWidth(); + + auto ruleToCall = Logic.CreateSplitModeDiff( + fn, retType, arg_activities, TA, returnPrimals, returnShadows, mode, + freeMemory, width, + /*addedType*/ nullptr, type_args, volatile_args, + /*augmented*/ nullptr, omp, postpasses, verifyPostPasses, + CI.getStrongZero()); + + OpBuilder builder(CI); + auto primalCall = builder.create( + CI.getLoc(), CI->getResultTypes(), ruleToCall, CI.getOperands()); + for (auto [oldRes, newRes] : + llvm::zip_equal(CI->getResults(), primalCall.getResults())) { + oldRes.replaceAllUsesWith(newRes); + } + + CI->erase(); + + // TODO: track tape usage to replace split_mode_reverse with call_custom_reverse + + return success(); + } + void lowerEnzymeCalls(SymbolTableCollection &symbolTable, FunctionOpInterface op) { { @@ -399,6 +546,30 @@ struct DifferentiatePass } } } + + { + SmallVector toLower; + op->walk([&](enzyme::AutoDiffDeferredPrimalOp dop) { + auto *symbolOp = + symbolTable.lookupNearestSymbolFrom(dop, dop.getFnAttr()); + auto callableOp = cast(symbolOp); + + lowerEnzymeCalls(symbolTable, callableOp); + toLower.push_back(dop); + }); + + for (auto T : toLower) { + if (auto F = dyn_cast(T)) { + auto res = HandleSplitModeAutoDiff(symbolTable, F); + if (!res.succeeded()) { + signalPassFailure(); + return; + } + } else { + llvm_unreachable("Illegal type"); + } + } + } }; }; From 8e34471843ef94fa5419a5af6eb0eec735a7b3ac Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Wed, 29 Oct 2025 11:46:37 +0100 Subject: [PATCH 03/26] wip: lower custom rule to func. --- .../LowerEnzymeCustomRulesToFuncPass.cpp | 289 ++++++++++++++++++ enzyme/Enzyme/MLIR/Passes/Passes.td | 8 + 2 files changed, 297 insertions(+) create mode 100644 enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp diff --git a/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp b/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp new file mode 100644 index 00000000000..f45698aee2f --- /dev/null +++ b/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp @@ -0,0 +1,289 @@ +//===- LowerEnzymeCustomRulesToFuncPass.cpp - ------------------------------- // +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +//===----------------------------------------------------------------------===// + +#include "Dialect/Ops.h" +#include "Passes/Passes.h" +#include "Passes/RemovalUtils.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Pass/PassManager.h" + +#define DEBUG_TYPE "enzyme" + +using namespace mlir; +using namespace mlir::enzyme; +using namespace enzyme; + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_LOWERENZYMECUSTOMRULESTOFUNCPASS +#include "Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +namespace { +struct LowerEnzymeCustomRulesToFuncPass + : public enzyme::impl::LowerEnzymeCustomRulesToFuncPassBase< + LowerEnzymeCustomRulesToFuncPass> { + using LowerEnzymeCustomRulesToFuncPassBase:: + LowerEnzymeCustomRulesToFuncPassBase; + + void runOnOperation() override; +}; +} // end anonymous namespace + +static LogicalResult +lowerCustomReverseRuleToFunc(enzyme::CustomReverseRuleOp revRule) { + SymbolTable symbolTable(SymbolTable::getNearestSymbolTable(revRule)); + + Block *bodyDef = &revRule.getBody().front(); + + enzyme::AugmentedPrimalOp primal = nullptr; + enzyme::ReverseOp reverse = nullptr; + + for (Operation &op : *bodyDef) { + if (auto AP = dyn_cast(op)) { + if (primal) { + AP->emitError() << "multiple augmented primal ops in a custom rule"; + return failure(); + } + primal = AP; + } else if (auto RO = dyn_cast(op)) { + if (reverse) { + RO->emitError() << "multiple reverse op in a custom rule"; + return failure(); + } + reverse = RO; + } + } + + bool singleBlock = + primal.getBody().hasOneBlock() && reverse.getBody().hasOneBlock(); + if (!singleBlock) { + // TODO: caching with non-structured control flow; + revRule->emitError() << "todo: lowering to func.func is not supported for " + "custom rules with more than one block."; + return failure(); + } + + auto funcType = revRule.getFunctionType(); + + SmallVector primalArgTypes(funcType.getInputs().begin(), + funcType.getInputs().end()); + SmallVector primalResultTypes(funcType.getResults().begin(), + funcType.getResults().end()); + + SmallVector reverseArgTypes; + for (auto [retTy, act] : + llvm::zip_equal(funcType.getResults(), revRule.getRetActivity())) { + + auto iattr = cast(act); + switch (iattr.getValue()) { + case mlir::enzyme::Activity::enzyme_active: + case mlir::enzyme::Activity::enzyme_activenoneed: + reverseArgTypes.push_back(retTy); + break; + case mlir::enzyme::Activity::enzyme_const: + break; + default: + llvm_unreachable("todo"); + } + } + + SmallVector reverseResultTypes; + for (auto [argTy, act] : + llvm::zip_equal(funcType.getInputs(), revRule.getActivity())) { + + auto iattr = cast(act); + switch (iattr.getValue()) { + case mlir::enzyme::Activity::enzyme_active: + case mlir::enzyme::Activity::enzyme_activenoneed: + reverseResultTypes.push_back(argTy); + break; + case mlir::enzyme::Activity::enzyme_const: + break; + default: + llvm_unreachable("todo"); + } + } + + SmallVector caches; + SmallVector cacheTypes; + for (Operation &op : *bodyDef) { + if (auto init = dyn_cast(&op)) { + auto CT = dyn_cast(init.getType()); + if (!CT) + continue; + + CacheInfo info(init.getResult()); + if (info.pushOp->getBlock() != &primal.getBody().front()) { + info.pushOp->emitError() + << "push operation not hoisted to the top level."; + return failure(); + } + + if (info.popOp->getBlock() != &reverse.getBody().front()) { + info.popOp->emitError() + << "pop operation not hoisted to the top level."; + return failure(); + } + + auto ET = CT.getType(); + cacheTypes.push_back(ET); + caches.push_back(info); + } + } + + primalResultTypes.append(cacheTypes.begin(), cacheTypes.end()); + reverseArgTypes.append(cacheTypes.begin(), cacheTypes.end()); + + auto revRuleName = revRule.getName(); + + FunctionType primalFuncType = FunctionType::get( + revRule->getContext(), primalArgTypes, primalResultTypes); + + SmallVector nameBuf; + Twine primalName = revRuleName + "_primal"; + Twine reverseName = revRuleName + "_reverse"; + + auto primalFunc = + func::FuncOp::create(primal.getLoc(), primalName.toStringRef(nameBuf), + primalFuncType, ArrayRef()); + + nameBuf.clear(); + + FunctionType reverseFuncType = FunctionType::get( + revRule->getContext(), reverseArgTypes, reverseResultTypes); + auto reverseFunc = + func::FuncOp::create(reverse.getLoc(), reverseName.toStringRef(nameBuf), + reverseFuncType, ArrayRef()); + + primalFunc.getBody().takeBody(primal.getBody()); + for (Block &b : primalFunc.getBody()) { + Operation *term = b.getTerminator(); + if (isa(term)) { + OpBuilder builder(term); + SmallVector toReturn(term->getOperands().begin(), + term->getOperands().end()); + for (auto &info : caches) { + toReturn.push_back(info.pushOp.getValue()); + info.pushOp->erase(); + } + builder.create(term->getLoc(), toReturn); + term->erase(); + } + } + + reverseFunc.getBody().takeBody(reverse.getBody()); + SmallVector cacheLocs = llvm::map_to_vector( + caches, [](CacheInfo info) { return info.initOp->getLoc(); }); + for (auto [info, arg] : llvm::zip_equal( + caches, + reverseFunc.getBody().front().addArguments(cacheTypes, cacheLocs))) { + info.popOp.getResult().replaceAllUsesWith(arg); + info.popOp->erase(); + info.initOp->erase(); + } + for (Block &b : reverseFunc.getBody()) { + Operation *term = b.getTerminator(); + if (isa(term)) { + OpBuilder builder(term); + builder.create(term->getLoc(), term->getOperands()); + term->erase(); + } + } + + symbolTable.insert(primalFunc); + SymbolTable::setSymbolVisibility(primalFunc, + SymbolTable::Visibility::Private); + + symbolTable.insert(reverseFunc); + SymbolTable::setSymbolVisibility(reverseFunc, + SymbolTable::Visibility::Private); + + auto uses = SymbolTable::getSymbolUses( + StringAttr::get(revRule->getContext(), revRuleName), symbolTable.getOp()); + if (!uses) { + revRule->erase(); + return success(); + } + + SmallVector tapes; + + SetVector toDelete; + + for (auto use : *uses) { + Operation *user = use.getUser(); + auto CAP = dyn_cast(user); + if (!CAP) + continue; + + OpBuilder builder(CAP); + auto primalCall = builder.create(CAP.getLoc(), primalFunc, + CAP->getOperands()); + + auto tape = CAP->getResult(CAP->getNumResults() - 1); + for (auto tapeUser : tape.getUsers()) { + if (auto CCR = dyn_cast(tapeUser)) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(CCR); + SmallVector operands( + CCR->getOperands().slice(0, CCR->getNumOperands() - 1).begin(), + CCR->getOperands().slice(0, CCR->getNumOperands() - 1).end()); + operands.append( + primalCall.getResults() + .slice(revRule.getFunctionType().getNumResults(), caches.size()) + .begin(), + primalCall.getResults() + .slice(revRule.getFunctionType().getNumResults(), caches.size()) + .end()); + auto reverseCall = + builder.create(CCR.getLoc(), reverseFunc, operands); + for (auto [oldRes, newRes] : + llvm::zip(CCR.getResults(), reverseCall.getResults())) { + oldRes.replaceAllUsesWith(newRes); + } + + toDelete.insert(CAP); + toDelete.insert(CCR); + } else { + tapeUser->emitError() + << "todo: support tape going through this operation"; + return failure(); + } + } + } + + toDelete.insert(revRule); + + auto worklist = toDelete.takeVector(); + while (!worklist.empty()) { + Operation *op = worklist.back(); + op->erase(); + worklist.pop_back(); + } + + return success(); +} + +void LowerEnzymeCustomRulesToFuncPass::runOnOperation() { + bool failed = false; + + getOperation()->walk([&failed](enzyme::CustomReverseRuleOp revRule) { + failed |= lowerCustomReverseRuleToFunc(revRule).failed(); + }); + + if (failed) { + signalPassFailure(); + return; + } +} diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index 27af977b29e..87496983515 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -212,6 +212,14 @@ def OutlineEnzymeFromRegionPass : Pass<"outline-enzyme-regions"> { ]; } +def LowerEnzymeCustomRulesToFuncPass : Pass<"lower-enzyme-custom-rules-to-func"> { + let summary = "Lower enzyme custom rules"; + let dependentDialects = [ + "func::FuncDialect", + "enzyme::EnzymeDialect" + ]; +} + def EnzymeOpsToMemRefPass : Pass<"convert-enzyme-to-memref"> { let summary = "Lower custom Enzyme ops to the MemRef dialect"; let dependentDialects = [ From 307c91d34c9d44431c1b4acdfb9d2e308caf1f07 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Wed, 29 Oct 2025 15:07:43 +0100 Subject: [PATCH 04/26] tape --- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 17 ++-- .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 2 +- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 28 +++++- enzyme/deferred.mlir | 92 +++++++++++++++++-- 4 files changed, 120 insertions(+), 19 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index bd906acb811..8c0a8495cf3 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -206,11 +206,13 @@ def AutoDiffOp : Enzyme_Op<"autodiff", }]; } +def AnyTape : Type($_self)">, "enzyme tape">; + def AutoDiffDeferredPrimalOp : Enzyme_Op<"autodiff_deferred_primal", [DeclareOpInterfaceMethods]> { let summary = "Runs an augmented primal that can later be used to generate reverse with enzyme.autodiff_deferred_reverse"; let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr:$width, DefaultValuedAttr:$strong_zero); - let results = (outs Variadic:$results, AnyType:$tape); + let results = (outs Variadic:$outputs, AnyTape:$tape); let assemblyFormat = [{ $fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results) @@ -220,11 +222,12 @@ def AutoDiffDeferredPrimalOp : Enzyme_Op<"autodiff_deferred_primal", def AutoDiffDeferredReverseOp : Enzyme_Op<"autodiff_deferred_reverse", [DeclareOpInterfaceMethods]> { let summary = "Runs the reverse from an enzyme.autodiff_deferred_primal result"; - let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity); - let results = (outs Variadic:$results); + let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, AnyTape:$tape, + ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity); + let results = (outs Variadic:$outputs); let assemblyFormat = [{ - $fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + $fn `(` operands `)` attr-dict `:` functional-type(operands, results) }]; } @@ -265,7 +268,7 @@ def CallAugmentedPrimalOp : Enzyme_Op<"call_augmented_primal", []> { let summary = ""; let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs); - let results = (outs Variadic:$outputs); + let results = (outs Variadic:$outputs, AnyTape:$tape); let assemblyFormat = [{ $fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results) @@ -276,11 +279,11 @@ def CallCustomReverseOp : Enzyme_Op<"call_custom_reverse", // [DeclareOpInterfaceMethods]> { []> { let summary = ""; - let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs); + let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, AnyTape:$tape); let results = (outs Variadic:$outputs); let assemblyFormat = [{ - $fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results) + $fn `(` operands `)` attr-dict `:` functional-type(operands, results) }]; } diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index bf747c81052..faf81772dfe 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -344,7 +344,7 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( fn.getContext(), mlir::enzyme::Activity::enzyme_active)); SmallVector retActivityAttrs; - for (auto act : constants) + for (auto act : retType) retActivityAttrs.push_back(mlir::enzyme::ActivityAttr::get( fn.getContext(), mlir::enzyme::Activity::enzyme_active)); diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index 2fcd252f5f2..26f4c6f0229 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -492,7 +492,33 @@ struct DifferentiatePass CI->erase(); - // TODO: track tape usage to replace split_mode_reverse with call_custom_reverse + SetVector toDelete; + + tape = primalCall.getTape(); + for (auto tapeUser : tape.getUsers()) { + if (auto revCall = + dyn_cast(tapeUser)) { + + OpBuilder builder(revCall); + auto newRevCall = builder.create( + revCall.getLoc(), revCall.getResultTypes(), ruleToCall, + revCall.getInputs(), tape); + revCall.replaceAllUsesWith(newRevCall.getResults()); + + toDelete.insert(revCall); + } else { + tapeUser->emitError() + << "todo: support tape going through this operation"; + return failure(); + } + } + + auto worklist = toDelete.takeVector(); + while (!worklist.empty()) { + Operation *op = worklist.back(); + op->erase(); + worklist.pop_back(); + } return success(); } diff --git a/enzyme/deferred.mlir b/enzyme/deferred.mlir index b4776add700..923bad3c6f6 100644 --- a/enzyme/deferred.mlir +++ b/enzyme/deferred.mlir @@ -5,7 +5,9 @@ module { return %0 : f32 } + // Split mode func.func @main() { + %a = arith.constant 1.0 : f32 %b = arith.constant 1.0 : f32 @@ -17,24 +19,94 @@ module { // --- %dres = arith.constant 1.0 : f32 - %da, %db = enzyme.autodiff_deferred_reverse @mul(%tape, %dres) { + %da, %db = enzyme.autodiff_deferred_reverse @mul(%dres, %tape) { activity=[#enzyme, #enzyme], ret_activity=[#enzyme] - } : (!enzyme.Tape, f32) -> (f32, f32) + } : (f32, !enzyme.Tape) -> (f32, f32) return } - func.func @mul_primal(%a: f32, %b: f32) -> (f32, !enzyme.Tape) { - %cache0 = "enzyme.init"() : () -> !enzyme.Cache - %cache1 = "enzyme.init"() : () -> !enzyme.Cache + // if there is a single block in the augmented primal and reverse, then the push/pop + // can be replaced with just the values carried over from primal to reverse. + enzyme.custom_reverse_rule @reverse_f { + %0 = "enzyme.init"() : () -> !enzyme.Cache + %1 = "enzyme.init"() : () -> !enzyme.Cache - %0 = arith.mulf %a, %b : f32 - "enzyme.push"(%cache0, %a) : (!enzyme.Cache, f32) -> () - "enzyme.push"(%cache1, %a) : (!enzyme.Cache, f32) -> () + enzyme.augmented_primal { + ^bb0(%a: f32, %b: f32): + "enzyme.push"(%0, %a) : (!enzyme.Cache, f32) -> () + "enzyme.push"(%1, %b) : (!enzyme.Cache, f32) -> () + + %res = arith.mulf %a, %b : f32 + + enzyme.yield %res : f32 + } + + enzyme.reverse { + ^bb0(%dres: f32): + %a = "enzyme.pop"(%0) : (!enzyme.Cache) -> f32 + %b = "enzyme.pop"(%1) : (!enzyme.Cache) -> f32 - %tape = "enzyme_.new_tape"(%cache0, %cache1) : (!enzyme.Cache, !enzyme.Cache) -> !enzyme.Tape - return %0, %tape : f32, !enzyme.Tape + %da = arith.mulf %b, %dres : f32 + %db = arith.mulf %a, %dres : f32 + + enzyme.yield %da, %db : f32, f32 + } + + enzyme.yield + } attributes { + activity=[#enzyme, + #enzyme], + ret_activity=[#enzyme], + function_type = (f32, f32) -> (f32) } + func.func @custom_rule_call(%a: f32, %b: f32, %dres_in: f32) -> (f32, f32) { + %res, %tape = enzyme.call_augmented_primal @reverse_f(%a, %b) : (f32, f32) -> (f32, !enzyme.Tape) + + %res_g = "enzyme.init"() : () -> (!enzyme.Gradient) + %zero = arith.constant 0.0 : f32 + + "enzyme.set"(%res_g, %dres_in) : (!enzyme.Gradient, f32) -> () + %dres = "enzyme.get"(%res_g) : (!enzyme.Gradient) -> f32 + %da, %db = enzyme.call_custom_reverse @reverse_f(%dres, %tape) : (f32, !enzyme.Tape) -> (f32, f32) + return %da, %db : f32, f32 + } + + enzyme.custom_reverse_rule @exp_f32 { + %cache = "enzyme.init"() : () -> !enzyme.Cache + enzyme.augmented_primal { + ^bb0(%arg0: f32): + %res = math.exp %arg0 : f32 + "enzyme.push"(%cache, %res) : (!enzyme.Cache, f32) -> () + enzyme.yield %res : f32 + } + enzyme.reverse { + ^bb0(%dres: f32): + %res = "enzyme.pop"(%cache) : (!enzyme.Cache) -> (f32) + %darg0 = arith.mulf %dres, %res : f32 + enzyme.yield %darg0 : f32 + } + enzyme.yield + } attributes { + activity=[#enzyme], + ret_activity=[#enzyme], + function_type = (f32) -> (f32) + } + + // TODO: split mode is implemented using custom_reverse_rule + AD + + // func.func @mul_primal(%a: f32, %b: f32) -> (f32, !enzyme.Tape) { + // %cache0 = "enzyme.init"() : () -> !enzyme.Cache + // %cache1 = "enzyme.init"() : () -> !enzyme.Cache + + // %0 = arith.mulf %a, %b : f32 + // "enzyme.push"(%cache0, %a) : (!enzyme.Cache, f32) -> () + // "enzyme.push"(%cache1, %a) : (!enzyme.Cache, f32) -> () + + // %tape = "enzyme_.new_tape"(%cache0, %cache1) : (!enzyme.Cache, !enzyme.Cache) -> !enzyme.Tape + // return %0, %tape : f32, !enzyme.Tape + // } + } From 0beea7898ff0e16abf4fe1f0ea92151591ab841b Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Thu, 30 Oct 2025 09:37:36 +0100 Subject: [PATCH 05/26] Math derivatives --- enzyme/Enzyme/MLIR/Implementations/Common.td | 5 ++++ .../Implementations/ComplexDerivatives.td | 30 +++++++++++++++++++ .../MathAutoDiffOpInterfaceImpl.cpp | 1 + .../MLIR/Implementations/MathDerivatives.td | 14 ++++++--- 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index c7c1f246774..eacdf52ea27 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -174,9 +174,14 @@ def CheckedDivF : SubRoutine<(Op $diffret, $x), def LlvmCheckedMulF : LlvmInst<"FMulOp">; def LlvmExpF : LlvmInst<"ExpOp">; +def ComplexCreate : ComplexInst<"CreateOp">; +def ComplexRe : ComplexInst<"ReOp">; +def ComplexIm : ComplexInst<"ImOp">; + def CosF : MathInst<"CosOp">; def SinF : MathInst<"SinOp">; def ExpF : MathInst<"ExpOp">; def SqrtF : MathInst<"SqrtOp">; +def AbsF : MathInst<"AbsFOp">; #endif // ENZYME_MLIR_IMPLEMENTATIONS_COMMON diff --git a/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td index 78630d65664..0e968aa8d32 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td @@ -15,3 +15,33 @@ def : MLIRDerivative<"complex", "MulOp", (Op $x, $y), (CMul (DiffeRet), $x) ] >; + +def : MLIRDerivative<"complex", "ImOp", (Op $x), + [ + (ComplexCreate + (TypeOf $x), + (ConstantFP<"0", "arith", "ConstantOp">), + (NegF (DiffeRet)) + ) + ], + (ComplexIm (Shadow $x)) + >; + +def : MLIRDerivative<"complex", "ReOp", (Op $x), + [ + (ComplexCreate + (TypeOf $x), + (DiffeRet), + (ConstantFP<"0", "arith", "ConstantOp">) + ) + ], + (ComplexRe (Shadow $x)) + >; + +def : MLIRDerivative<"complex", "CreateOp", (Op $re, $im), + [ + (ComplexRe (DiffeRet)), + (ComplexIm (DiffeRet)) + ] + >; + diff --git a/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp index 2833eeb4472..3a5d15d943d 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MathAutoDiffOpInterfaceImpl.cpp @@ -16,6 +16,7 @@ #include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" diff --git a/enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td index d52603f162e..d7a53d7f8ee 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/MathDerivatives.td @@ -10,10 +10,10 @@ def : MLIRDerivative<"math", "ExpOp", (Op $x), (CheckedMulF (DiffeRet), (ExpF $x)) ] >; -def : MLIRDerivative<"math", "SinOp", (Op $x), +def : MLIRDerivative<"math", "SinOp", (Op $x), [ (CheckedMulF (DiffeRet), (CosF $x)) - ] + ] >; def : MLIRDerivative<"math", "SqrtOp", (Op $x), [ @@ -27,7 +27,13 @@ def : MLIRDerivative<"math", "AtanOp", (Op $x), >; def : MLIRDerivative<"math", "AbsFOp", (Op $x), [ - // TODO: handle complex - (Arith_Select (CmpF (Arith_OGE), $x, (ConstantFP<"0","arith","ConstantOp"> $x)), (DiffeRet), (NegF (DiffeRet))) + (SelectIfComplex $x, + (ComplexCreate + (TypeOf $x), + (DivF (ComplexRe $x), (AbsF $x)), + (NegF (DivF (ComplexIm $x), (AbsF $x))) + ), + (Arith_Select (CmpF (Arith_OGE), $x, (ConstantFP<"0","arith","ConstantOp"> $x)), (DiffeRet), (NegF (DiffeRet))) + ) ] >; From d35def9802583c563147bd85b4cbbcc36a7817ab Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Thu, 30 Oct 2025 09:37:50 +0100 Subject: [PATCH 06/26] update --- enzyme/deferred.mlir | 2 -- 1 file changed, 2 deletions(-) diff --git a/enzyme/deferred.mlir b/enzyme/deferred.mlir index 923bad3c6f6..95eb760748e 100644 --- a/enzyme/deferred.mlir +++ b/enzyme/deferred.mlir @@ -27,8 +27,6 @@ module { return } - // if there is a single block in the augmented primal and reverse, then the push/pop - // can be replaced with just the values carried over from primal to reverse. enzyme.custom_reverse_rule @reverse_f { %0 = "enzyme.init"() : () -> !enzyme.Cache %1 = "enzyme.init"() : () -> !enzyme.Cache From 75a1af72adf1e0b04251102d7675aafaf5c30b42 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Fri, 31 Oct 2025 10:37:59 +0100 Subject: [PATCH 07/26] tape type --- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index 8c0a8495cf3..f2bfe99a928 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -206,7 +206,7 @@ def AutoDiffOp : Enzyme_Op<"autodiff", }]; } -def AnyTape : Type($_self)">, "enzyme tape">; +def AnyTape : Type($_self)">, "enzyme tape">; def AutoDiffDeferredPrimalOp : Enzyme_Op<"autodiff_deferred_primal", [DeclareOpInterfaceMethods]> { From 9590a19747459241b48d9b0c863334ef7555c3b1 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Fri, 31 Oct 2025 14:12:09 +0100 Subject: [PATCH 08/26] fmt --- enzyme/Enzyme/MLIR/Dialect/Ops.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index 91cbfe07d8e..0f3b20a2753 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -543,8 +543,8 @@ AutoDiffDeferredPrimalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } -LogicalResult -AutoDiffDeferredReverseOp::verifySymbolUses(SymbolTableCollection &symbolTable) { +LogicalResult AutoDiffDeferredReverseOp::verifySymbolUses( + SymbolTableCollection &symbolTable) { // TODO: Verify that the result type is same as the type of the referenced // func.func op. auto global = From 6a7c0cff00dbbffd215e2df96baa4e84a3fc4e82 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Mon, 3 Nov 2025 14:14:47 +0100 Subject: [PATCH 09/26] update --- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 14 ++++++++++---- enzyme/Enzyme/MLIR/Dialect/Ops.cpp | 6 +++--- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 10 +++++----- .../Passes/LowerEnzymeCustomRulesToFuncPass.cpp | 2 ++ enzyme/deferred.mlir | 4 ++-- 5 files changed, 22 insertions(+), 14 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index f2bfe99a928..ca1f001a17a 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -208,7 +208,7 @@ def AutoDiffOp : Enzyme_Op<"autodiff", def AnyTape : Type($_self)">, "enzyme tape">; -def AutoDiffDeferredPrimalOp : Enzyme_Op<"autodiff_deferred_primal", +def AutoDiffSplitModePrimalOp : Enzyme_Op<"autodiff_split_mode.primal", [DeclareOpInterfaceMethods]> { let summary = "Runs an augmented primal that can later be used to generate reverse with enzyme.autodiff_deferred_reverse"; let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr:$width, DefaultValuedAttr:$strong_zero); @@ -219,7 +219,7 @@ def AutoDiffDeferredPrimalOp : Enzyme_Op<"autodiff_deferred_primal", }]; } -def AutoDiffDeferredReverseOp : Enzyme_Op<"autodiff_deferred_reverse", +def AutoDiffSplitModeReverseOp : Enzyme_Op<"autodiff_split_mode.reverse", [DeclareOpInterfaceMethods]> { let summary = "Runs the reverse from an enzyme.autodiff_deferred_primal result"; let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, AnyTape:$tape, @@ -242,7 +242,10 @@ def CustomReverseRuleOp : Enzyme_Op<"custom_reverse_rule", [IsolatedFromAbove]> }]; } -def AugmentedPrimalOp : Enzyme_Op<"augmented_primal", [HasParent<"CustomReverseRuleOp">]> { +def AugmentedPrimalOp : Enzyme_Op<"augmented_primal", [ + HasParent<"CustomReverseRuleOp">, + AutomaticAllocationScope, + AffineScope]> { let summary = "Defines the augmented primal for a custom reverse rule"; let arguments = (ins); let regions = (region AnyRegion:$body); @@ -253,7 +256,10 @@ def AugmentedPrimalOp : Enzyme_Op<"augmented_primal", [HasParent<"CustomReverseR }]; } -def ReverseOp : Enzyme_Op<"reverse", [HasParent<"CustomReverseRuleOp">]> { +def ReverseOp : Enzyme_Op<"reverse", [ + HasParent<"CustomReverseRuleOp">, + AutomaticAllocationScope, + AffineScope]> { let summary = "Defined the reverse for a custom rule."; let arguments = (ins); let regions = (region AnyRegion:$body); diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index 0f3b20a2753..83f0be07579 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -530,8 +530,8 @@ LogicalResult BatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // AutoDiffDeferredOp //===----------------------------------------------------------------------===// -LogicalResult -AutoDiffDeferredPrimalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { +LogicalResult AutoDiffSplitModePrimalOp::verifySymbolUses( + SymbolTableCollection &symbolTable) { // TODO: Verify that the result type is same as the type of the referenced // func.func op. auto global = @@ -543,7 +543,7 @@ AutoDiffDeferredPrimalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } -LogicalResult AutoDiffDeferredReverseOp::verifySymbolUses( +LogicalResult AutoDiffSplitModeReverseOp::verifySymbolUses( SymbolTableCollection &symbolTable) { // TODO: Verify that the result type is same as the type of the referenced // func.func op. diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index 26f4c6f0229..a5e09ec99f1 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -351,12 +351,12 @@ struct DifferentiatePass } LogicalResult HandleSplitModeAutoDiff(SymbolTableCollection &symbolTable, - enzyme::AutoDiffDeferredPrimalOp CI) { + enzyme::AutoDiffSplitModePrimalOp CI) { auto tape = CI.getTape(); SmallVector reverseCalls; for (auto user : tape.getUsers()) { - if (isa(user)) + if (isa(user)) reverseCalls.push_back(user); else { user->emitError() << "todo: unsupported tape usage"; @@ -497,7 +497,7 @@ struct DifferentiatePass tape = primalCall.getTape(); for (auto tapeUser : tape.getUsers()) { if (auto revCall = - dyn_cast(tapeUser)) { + dyn_cast(tapeUser)) { OpBuilder builder(revCall); auto newRevCall = builder.create( @@ -575,7 +575,7 @@ struct DifferentiatePass { SmallVector toLower; - op->walk([&](enzyme::AutoDiffDeferredPrimalOp dop) { + op->walk([&](enzyme::AutoDiffSplitModePrimalOp dop) { auto *symbolOp = symbolTable.lookupNearestSymbolFrom(dop, dop.getFnAttr()); auto callableOp = cast(symbolOp); @@ -585,7 +585,7 @@ struct DifferentiatePass }); for (auto T : toLower) { - if (auto F = dyn_cast(T)) { + if (auto F = dyn_cast(T)) { auto res = HandleSplitModeAutoDiff(symbolTable, F); if (!res.succeeded()) { signalPassFailure(); diff --git a/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp b/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp index f45698aee2f..9780ec22068 100644 --- a/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp @@ -6,6 +6,8 @@ // //===----------------------------------------------------------------------===// // +// This file defines a pass to lower enzyme custom rules to the func dialect. +// //===----------------------------------------------------------------------===// #include "Dialect/Ops.h" diff --git a/enzyme/deferred.mlir b/enzyme/deferred.mlir index 95eb760748e..c8f70ca047f 100644 --- a/enzyme/deferred.mlir +++ b/enzyme/deferred.mlir @@ -11,7 +11,7 @@ module { %a = arith.constant 1.0 : f32 %b = arith.constant 1.0 : f32 - %r, %tape = enzyme.autodiff_deferred_primal @mul(%a, %b) { + %r, %tape = enzyme.autodiff_split_mode.primal @mul(%a, %b) { activity=[#enzyme, #enzyme], ret_activity=[#enzyme] } : (f32, f32) -> (f32, !enzyme.Tape) @@ -19,7 +19,7 @@ module { // --- %dres = arith.constant 1.0 : f32 - %da, %db = enzyme.autodiff_deferred_reverse @mul(%dres, %tape) { + %da, %db = enzyme.autodiff_split_mode.reverse @mul(%dres, %tape) { activity=[#enzyme, #enzyme], ret_activity=[#enzyme] } : (f32, !enzyme.Tape) -> (f32, f32) From 122f9ccc6940835902a8f9392efd3bcf376a99f3 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Mon, 3 Nov 2025 14:57:24 +0100 Subject: [PATCH 10/26] dup args --- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 6 +-- .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 54 ++++++++++++++++--- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 4 +- .../LowerEnzymeCustomRulesToFuncPass.cpp | 21 +++++--- 4 files changed, 67 insertions(+), 18 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index ca1f001a17a..75075a720d1 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -242,7 +242,7 @@ def CustomReverseRuleOp : Enzyme_Op<"custom_reverse_rule", [IsolatedFromAbove]> }]; } -def AugmentedPrimalOp : Enzyme_Op<"augmented_primal", [ +def CustomReverseRuleAugmentedPrimalOp : Enzyme_Op<"custom_reverse_rule.augmented_primal", [ HasParent<"CustomReverseRuleOp">, AutomaticAllocationScope, AffineScope]> { @@ -256,7 +256,7 @@ def AugmentedPrimalOp : Enzyme_Op<"augmented_primal", [ }]; } -def ReverseOp : Enzyme_Op<"reverse", [ +def CustomReverseRuleReverseOp : Enzyme_Op<"custom_reverse_rule.reverse", [ HasParent<"CustomReverseRuleOp">, AutomaticAllocationScope, AffineScope]> { @@ -331,7 +331,7 @@ def AutoDiffRegionOp : Enzyme_Op<"autodiff_region", [AutomaticAllocationScope]> } def YieldOp : Enzyme_Op<"yield", [Pure, ReturnLike, Terminator, - ParentOneOf<["AutoDiffRegionOp", "LoopOp", "ReverseOp", "AugmentedPrimalOp", "CustomReverseRuleOp"]>]> { + ParentOneOf<["AutoDiffRegionOp", "LoopOp", "CustomReverseRuleReverseOp", "CustomReverseRuleAugmentedPrimalOp", "CustomReverseRuleOp"]>]> { let summary = "Yield values at the end of an autodiff_region, loop op, reverse op, aug primal op or custom reverse rule op"; let arguments = (ins Variadic:$operands); let assemblyFormat = [{ diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index faf81772dfe..e5fcd8b0bef 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -307,6 +307,18 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( return nf; } +static mlir::enzyme::ActivityAttr activityFromDiffeType(mlir::MLIRContext *ctx, + DIFFE_TYPE ty) { + auto activity = mlir::enzyme::Activity::enzyme_active; + switch (ty) { + case DIFFE_TYPE::DUP_ARG: + activity = mlir::enzyme::Activity::enzyme_dup; + default: + break; + }; + return mlir::enzyme::ActivityAttr::get(ctx, activity); +} + FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( FunctionOpInterface fn, std::vector retType, std::vector constants, MTypeAnalysis &TA, @@ -331,7 +343,6 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( else nonconstant_values.insert(arg); } - IRMapping invertedPointers; SmallVector returnPrimalsP(returnPrimals.begin(), returnPrimals.end()); SmallVector returnShadowsP(returnShadows.begin(), returnShadows.end()); @@ -340,13 +351,11 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( SmallVector argActivityAttrs; for (auto act : constants) - argActivityAttrs.push_back(mlir::enzyme::ActivityAttr::get( - fn.getContext(), mlir::enzyme::Activity::enzyme_active)); + argActivityAttrs.push_back(activityFromDiffeType(fn.getContext(), act)); SmallVector retActivityAttrs; for (auto act : retType) - retActivityAttrs.push_back(mlir::enzyme::ActivityAttr::get( - fn.getContext(), mlir::enzyme::Activity::enzyme_active)); + retActivityAttrs.push_back(activityFromDiffeType(fn.getContext(), act)); auto argActivityAttr = ArrayAttr::get(fn.getContext(), argActivityAttrs); auto retActivityAttr = ArrayAttr::get(fn.getContext(), retActivityAttrs); @@ -357,6 +366,10 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( auto ruleNameAttr = FlatSymbolRefAttr::get( fn.getContext(), customRuleName.toStringRef(nameBuf)); + SmallVector argTys( + cast(fn.getFunctionType()).getInputs().begin(), + cast(fn.getFunctionType()).getInputs().end()); + OpBuilder builder(fn); auto customRule = enzyme::CustomReverseRuleOp::create( builder, fn.getLoc(), ruleNameAttr, TypeAttr::get(fn.getFunctionType()), @@ -367,7 +380,8 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( OpBuilder ruleBuilder(ruleBody, ruleBody->begin()); - auto reverse = ruleBuilder.create(fn.getLoc()); + auto reverse = + ruleBuilder.create(fn.getLoc()); ruleBuilder.create(fn.getLoc(), ValueRange{}); ruleBuilder.setInsertionPoint(reverse); @@ -391,6 +405,31 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( cloneInto(&fn.getFunctionBody(), &newFunc.getFunctionBody(), originalToNew, originalToNewOps); + Block *fnEntry = &newFunc.getFunctionBody().front(); + IRMapping invertedPointers; + + SmallVector newArgTys; + + int numDup = 0; + for (auto [act, arg] : llvm::zip_equal( + constants, fn.getFunctionBody().front().getArguments())) { + newArgTys.push_back(arg.getType()); + if (act == DIFFE_TYPE::DUP_ARG) { + numDup++; + auto shadow = fnEntry->insertArgument(arg.getArgNumber() + numDup, + arg.getType(), // shadow + arg.getLoc()); + // argTys.insert(arg.getArgNumber() + numDup, arg.getType()); + newArgTys.push_back(arg.getType()); + invertedPointers.map(arg, shadow); + } + } + + auto newFuncType = + FunctionType::get(newFunc.getContext(), newArgTys, + cast(fn.getFunctionType()).getResults()); + newFunc.setFunctionTypeAttr(TypeAttr::get(newFuncType)); + llvm::errs() << "new func = " << newFunc << "\n"; MGradientUtilsReverse *gutils = new MGradientUtilsReverse( @@ -450,7 +489,8 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( ruleBuilder.setInsertionPoint(reverse); auto augmentedPrimal = - ruleBuilder.create(fn.getLoc()); + ruleBuilder.create( + fn.getLoc()); augmentedPrimal.getBody().takeBody(newFunc.getFunctionBody()); for (Block &b : augmentedPrimal.getBody()) { if (b.getNumSuccessors() == 0) { diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index a5e09ec99f1..1580107fd56 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -372,13 +372,13 @@ struct DifferentiatePass assert(fn); if (CI.getActivity().size() != fn.getNumArguments()) { llvm::errs() << "Incorrect number of argument activities on autodiff op" - << "CI: " << CI << ", expected " << fn.getNumArguments() + << " CI: " << CI << ", expected " << fn.getNumArguments() << " found " << CI.getActivity().size() << "\n"; return failure(); } if (CI.getRetActivity().size() != fn.getNumResults()) { llvm::errs() << "Incorrect number of result activities on autodiff op" - << "CI: " << CI << ", expected " << fn.getNumResults() + << " CI: " << CI << ", expected " << fn.getNumResults() << " found " << CI.getRetActivity().size() << "\n"; return failure(); } diff --git a/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp b/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp index 9780ec22068..1d136069ee6 100644 --- a/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp @@ -49,17 +49,17 @@ lowerCustomReverseRuleToFunc(enzyme::CustomReverseRuleOp revRule) { Block *bodyDef = &revRule.getBody().front(); - enzyme::AugmentedPrimalOp primal = nullptr; - enzyme::ReverseOp reverse = nullptr; + enzyme::CustomReverseRuleAugmentedPrimalOp primal = nullptr; + enzyme::CustomReverseRuleReverseOp reverse = nullptr; for (Operation &op : *bodyDef) { - if (auto AP = dyn_cast(op)) { + if (auto AP = dyn_cast(op)) { if (primal) { AP->emitError() << "multiple augmented primal ops in a custom rule"; return failure(); } primal = AP; - } else if (auto RO = dyn_cast(op)) { + } else if (auto RO = dyn_cast(op)) { if (reverse) { RO->emitError() << "multiple reverse op in a custom rule"; return failure(); @@ -79,8 +79,7 @@ lowerCustomReverseRuleToFunc(enzyme::CustomReverseRuleOp revRule) { auto funcType = revRule.getFunctionType(); - SmallVector primalArgTypes(funcType.getInputs().begin(), - funcType.getInputs().end()); + SmallVector primalArgTypes; SmallVector primalResultTypes(funcType.getResults().begin(), funcType.getResults().end()); @@ -110,8 +109,14 @@ lowerCustomReverseRuleToFunc(enzyme::CustomReverseRuleOp revRule) { case mlir::enzyme::Activity::enzyme_active: case mlir::enzyme::Activity::enzyme_activenoneed: reverseResultTypes.push_back(argTy); + primalArgTypes.push_back(argTy); break; case mlir::enzyme::Activity::enzyme_const: + primalArgTypes.push_back(argTy); + break; + case mlir::enzyme::Activity::enzyme_dup: + primalArgTypes.push_back(argTy); + primalArgTypes.push_back(argTy); break; default: llvm_unreachable("todo"); @@ -232,6 +237,10 @@ lowerCustomReverseRuleToFunc(enzyme::CustomReverseRuleOp revRule) { OpBuilder builder(CAP); auto primalCall = builder.create(CAP.getLoc(), primalFunc, CAP->getOperands()); + for (auto [oldRes, newRes] : + llvm::zip(CAP.getOutputs(), primalCall->getResults())) { + oldRes.replaceAllUsesWith(newRes); + } auto tape = CAP->getResult(CAP->getNumResults() - 1); for (auto tapeUser : tape.getUsers()) { From 9d16ca3bfb7cb35b6945e5a7ef98437f92481b69 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Mon, 3 Nov 2025 15:08:54 +0100 Subject: [PATCH 11/26] Shadow type --- .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 22 ++------------ .../LowerEnzymeCustomRulesToFuncPass.cpp | 4 ++- enzyme/deferred.mlir | 30 ++++++++++++++++--- 3 files changed, 32 insertions(+), 24 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index e5fcd8b0bef..1ac253abccd 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -386,21 +386,6 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( ruleBuilder.setInsertionPoint(reverse); - // FunctionOpInterface newF = - // cast(fn->cloneWithoutRegions()); - // SymbolTable::setSymbolName( - // newF, StringAttr::get(fn->getContext(), name + "_primal")); - - // FunctionOpInterface newFRev = - // cast(fn->cloneWithoutRegions()); - // SymbolTable::setSymbolName( - // newFRev, StringAttr::get(fn->getContext(), name + "_reverse")); - // - // cloneInto(&fn.getFunctionBody(), &newF.getFunctionBody(), originalToNew, - // originalToNewOps); - // - // llvm::errs() << "fn = " << newF << "\n"; - auto newFunc = cast(fn->cloneWithoutRegions()); cloneInto(&fn.getFunctionBody(), &newFunc.getFunctionBody(), originalToNew, originalToNewOps); @@ -416,11 +401,10 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( newArgTys.push_back(arg.getType()); if (act == DIFFE_TYPE::DUP_ARG) { numDup++; + auto shadowType = cast(arg.getType()).getShadowType(width); auto shadow = fnEntry->insertArgument(arg.getArgNumber() + numDup, - arg.getType(), // shadow - arg.getLoc()); - // argTys.insert(arg.getArgNumber() + numDup, arg.getType()); - newArgTys.push_back(arg.getType()); + shadowType, arg.getLoc()); + newArgTys.push_back(shadowType); invertedPointers.map(arg, shadow); } } diff --git a/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp b/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp index 1d136069ee6..d582a62be9a 100644 --- a/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "Dialect/Ops.h" +#include "Interfaces/AutoDiffTypeInterface.h" #include "Passes/Passes.h" #include "Passes/RemovalUtils.h" @@ -116,7 +117,8 @@ lowerCustomReverseRuleToFunc(enzyme::CustomReverseRuleOp revRule) { break; case mlir::enzyme::Activity::enzyme_dup: primalArgTypes.push_back(argTy); - primalArgTypes.push_back(argTy); + primalArgTypes.push_back( + cast(argTy).getShadowType(/*width*/ 1)); break; default: llvm_unreachable("todo"); diff --git a/enzyme/deferred.mlir b/enzyme/deferred.mlir index c8f70ca047f..30c1c04cf11 100644 --- a/enzyme/deferred.mlir +++ b/enzyme/deferred.mlir @@ -31,7 +31,7 @@ module { %0 = "enzyme.init"() : () -> !enzyme.Cache %1 = "enzyme.init"() : () -> !enzyme.Cache - enzyme.augmented_primal { + enzyme.custom_reverse_rule.augmented_primal { ^bb0(%a: f32, %b: f32): "enzyme.push"(%0, %a) : (!enzyme.Cache, f32) -> () "enzyme.push"(%1, %b) : (!enzyme.Cache, f32) -> () @@ -41,7 +41,7 @@ module { enzyme.yield %res : f32 } - enzyme.reverse { + enzyme.custom_reverse_rule.reverse { ^bb0(%dres: f32): %a = "enzyme.pop"(%0) : (!enzyme.Cache) -> f32 %b = "enzyme.pop"(%1) : (!enzyme.Cache) -> f32 @@ -74,13 +74,13 @@ module { enzyme.custom_reverse_rule @exp_f32 { %cache = "enzyme.init"() : () -> !enzyme.Cache - enzyme.augmented_primal { + enzyme.custom_reverse_rule.augmented_primal { ^bb0(%arg0: f32): %res = math.exp %arg0 : f32 "enzyme.push"(%cache, %res) : (!enzyme.Cache, f32) -> () enzyme.yield %res : f32 } - enzyme.reverse { + enzyme.custom_reverse_rule.reverse { ^bb0(%dres: f32): %res = "enzyme.pop"(%cache) : (!enzyme.Cache) -> (f32) %darg0 = arith.mulf %dres, %res : f32 @@ -93,6 +93,28 @@ module { function_type = (f32) -> (f32) } + func.func @f_dup(%a: !llvm.ptr) -> f32 { + %0 = llvm.load %a : !llvm.ptr -> f32 + return %0 : f32 + } + + func.func @ff_dup(%a: !llvm.ptr, %b: !llvm.ptr) -> f32 { + %r, %tape = enzyme.autodiff_split_mode.primal @f_dup(%a, %b) { + activity=[#enzyme], + ret_activity=[#enzyme] + } : (!llvm.ptr, !llvm.ptr) -> (f32, !enzyme.Tape) + + // + + %dres = arith.constant 1.0 : f32 + enzyme.autodiff_split_mode.reverse @f_dup(%dres, %tape) { + activity=[#enzyme], + ret_activity=[#enzyme] + } : (f32, !enzyme.Tape) -> () + + return %r : f32 + } + // TODO: split mode is implemented using custom_reverse_rule + AD // func.func @mul_primal(%a: f32, %b: f32) -> (f32, !enzyme.Tape) { From 29c0bdd6f0c298365eabd361861cb733ad04aad5 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Mon, 3 Nov 2025 15:27:09 +0100 Subject: [PATCH 12/26] fmt --- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index 1ac253abccd..546622391c8 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -401,7 +401,8 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( newArgTys.push_back(arg.getType()); if (act == DIFFE_TYPE::DUP_ARG) { numDup++; - auto shadowType = cast(arg.getType()).getShadowType(width); + auto shadowType = + cast(arg.getType()).getShadowType(width); auto shadow = fnEntry->insertArgument(arg.getArgNumber() + numDup, shadowType, arg.getLoc()); newArgTys.push_back(shadowType); From d78841f42c7024c78aa56a200e45aabb3dcbe744 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Mon, 3 Nov 2025 16:31:50 +0100 Subject: [PATCH 13/26] Symbol Interface for custom rule --- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 16 +++---- enzyme/Enzyme/MLIR/Dialect/Ops.cpp | 43 ++++++++++++++++++- .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 6 +-- 3 files changed, 51 insertions(+), 14 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index 75075a720d1..bb2809c36ea 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -231,14 +231,14 @@ def AutoDiffSplitModeReverseOp : Enzyme_Op<"autodiff_split_mode.reverse", }]; } -def CustomReverseRuleOp : Enzyme_Op<"custom_reverse_rule", [IsolatedFromAbove]> { +def CustomReverseRuleOp : Enzyme_Op<"custom_reverse_rule", [IsolatedFromAbove, Symbol]> { let summary = "Parent operation for custom reverse rule declaration."; - let arguments = (ins FlatSymbolRefAttr:$name, TypeAttrOf:$function_type, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity); + let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf:$function_type, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity); let regions = (region AnyRegion:$body); let results = (outs); let assemblyFormat = [{ - $name $body attr-dict-with-keyword + $sym_name $body attr-dict-with-keyword }]; } @@ -269,9 +269,8 @@ def CustomReverseRuleReverseOp : Enzyme_Op<"custom_reverse_rule.reverse", [ }]; } -def CallAugmentedPrimalOp : Enzyme_Op<"call_augmented_primal", - // [DeclareOpInterfaceMethods]> { - []> { +def CallAugmentedPrimalOp : Enzyme_Op<"call_augmented_primal", [ + DeclareOpInterfaceMethods]> { let summary = ""; let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs); let results = (outs Variadic:$outputs, AnyTape:$tape); @@ -281,9 +280,8 @@ def CallAugmentedPrimalOp : Enzyme_Op<"call_augmented_primal", }]; } -def CallCustomReverseOp : Enzyme_Op<"call_custom_reverse", - // [DeclareOpInterfaceMethods]> { - []> { +def CallCustomReverseOp : Enzyme_Op<"call_custom_reverse", [ + DeclareOpInterfaceMethods]> { let summary = ""; let arguments = (ins FlatSymbolRefAttr:$fn, Variadic:$inputs, AnyTape:$tape); let results = (outs Variadic:$outputs); diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index 83f0be07579..8b79b3c8b61 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -527,7 +527,7 @@ LogicalResult BatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) { } //===----------------------------------------------------------------------===// -// AutoDiffDeferredOp +// AutoDiffSplitModePrimalOp //===----------------------------------------------------------------------===// LogicalResult AutoDiffSplitModePrimalOp::verifySymbolUses( @@ -543,6 +543,10 @@ LogicalResult AutoDiffSplitModePrimalOp::verifySymbolUses( return success(); } +//===----------------------------------------------------------------------===// +// AutoDiffSplitModeReverseOp +//===----------------------------------------------------------------------===// + LogicalResult AutoDiffSplitModeReverseOp::verifySymbolUses( SymbolTableCollection &symbolTable) { // TODO: Verify that the result type is same as the type of the referenced @@ -556,6 +560,43 @@ LogicalResult AutoDiffSplitModeReverseOp::verifySymbolUses( return success(); } +//===----------------------------------------------------------------------===// +// CallAugmentedPrimalOp +//===----------------------------------------------------------------------===// + +LogicalResult CallAugmentedPrimalOp::verifySymbolUses( + SymbolTableCollection &symbolTable) { + auto global = + symbolTable.lookupNearestSymbolFrom(*this, getFnAttr()); + if (!global) + return emitOpError("'") + << getFn() << "' does not reference a valid custom reverse rule"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// CustomReverseRuleOp +//===----------------------------------------------------------------------===// + +// mlir::StringAttr + +//===----------------------------------------------------------------------===// +// CallCustomReverseOp +//===----------------------------------------------------------------------===// + +LogicalResult CallCustomReverseOp::verifySymbolUses( + SymbolTableCollection &symbolTable) { + auto global = + symbolTable.lookupNearestSymbolFrom(*this, getFnAttr()); + if (!global) + return emitOpError("'") + << getFn() << "' does not reference a valid custom reverse rule"; + + return success(); +} + + //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index 546622391c8..85deaeea0c4 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -363,7 +363,7 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( auto customRuleName = name + "_reverse_rule"; SmallVector nameBuf; - auto ruleNameAttr = FlatSymbolRefAttr::get( + auto ruleNameAttr = StringAttr::get( fn.getContext(), customRuleName.toStringRef(nameBuf)); SmallVector argTys( @@ -415,8 +415,6 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( cast(fn.getFunctionType()).getResults()); newFunc.setFunctionTypeAttr(TypeAttr::get(newFuncType)); - llvm::errs() << "new func = " << newFunc << "\n"; - MGradientUtilsReverse *gutils = new MGradientUtilsReverse( *this, newFunc, fn, TA, invertedPointers, returnPrimalsP, returnShadowsP, constant_values, nonconstant_values, retType, constants, originalToNew, @@ -490,5 +488,5 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( newFunc->erase(); - return ruleNameAttr; + return FlatSymbolRefAttr::get(ruleNameAttr); } From 0cf00d6af7593eee15854ff3958448d2b4335cc6 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Mon, 3 Nov 2025 16:32:38 +0100 Subject: [PATCH 14/26] Cleanup --- enzyme/Enzyme/MLIR/Dialect/Ops.cpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index 8b79b3c8b61..4988579db37 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -575,12 +575,6 @@ LogicalResult CallAugmentedPrimalOp::verifySymbolUses( return success(); } -//===----------------------------------------------------------------------===// -// CustomReverseRuleOp -//===----------------------------------------------------------------------===// - -// mlir::StringAttr - //===----------------------------------------------------------------------===// // CallCustomReverseOp //===----------------------------------------------------------------------===// From be8f242bda9559e609e3e32b8a32f3df8188a430 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Mon, 3 Nov 2025 17:29:11 +0100 Subject: [PATCH 15/26] Each subop has its own function type and custom parser/printer --- enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 13 ++- enzyme/Enzyme/MLIR/Dialect/Ops.cpp | 81 +++++++++++++++++-- .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 9 ++- enzyme/deferred.mlir | 12 +-- 4 files changed, 90 insertions(+), 25 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index bb2809c36ea..40421d3fe65 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -247,13 +247,11 @@ def CustomReverseRuleAugmentedPrimalOp : Enzyme_Op<"custom_reverse_rule.augmente AutomaticAllocationScope, AffineScope]> { let summary = "Defines the augmented primal for a custom reverse rule"; - let arguments = (ins); + let arguments = (ins TypeAttrOf:$function_type); let regions = (region AnyRegion:$body); let results = (outs); - let assemblyFormat = [{ - $body attr-dict-with-keyword - }]; + let hasCustomAssemblyFormat = 1; } def CustomReverseRuleReverseOp : Enzyme_Op<"custom_reverse_rule.reverse", [ @@ -261,12 +259,11 @@ def CustomReverseRuleReverseOp : Enzyme_Op<"custom_reverse_rule.reverse", [ AutomaticAllocationScope, AffineScope]> { let summary = "Defined the reverse for a custom rule."; - let arguments = (ins); + let arguments = (ins TypeAttrOf:$function_type); let regions = (region AnyRegion:$body); let results = (outs); - let assemblyFormat = [{ - $body attr-dict-with-keyword - }]; + + let hasCustomAssemblyFormat = 1; } def CallAugmentedPrimalOp : Enzyme_Op<"call_augmented_primal", [ diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index 4988579db37..130a8d3ac4c 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" +#include "mlir/Interfaces/FunctionImplementation.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -560,14 +561,80 @@ LogicalResult AutoDiffSplitModeReverseOp::verifySymbolUses( return success(); } +static ParseResult parseAugmentedFn(OpAsmParser &parser, + OperationState &result) { + SmallVector argTys, resTys; + SmallVector resAttrs; + + bool isVariadic = false; + SmallVector arguments; + if (failed(function_interface_impl::parseFunctionSignatureWithArguments( + parser, /*allowVariadic*/ false, arguments, isVariadic, resTys, + resAttrs))) + return failure(); + + auto *body = result.addRegion(); + if (failed( + parser.parseRegion(*body, arguments, /*enableNameShadowing*/ false))) + return failure(); + + result.addAttribute( + "function_type", + TypeAttr::get(FunctionType::get(result.getContext(), argTys, resTys))); + + return success(); +} + +static void printAugmentedFn(OpAsmPrinter &p, FunctionType fnType, + Region &body) { + p << ' '; + + call_interface_impl::printFunctionSignature( + p, fnType.getInputs(), nullptr, /*isVariadic*/ false, fnType.getResults(), + nullptr, &body, /*printEmptyResult*/ false); + + p << ' '; + p.printRegion(body, /*printEntryBlockArgs*/ false, + /*printBlockTerminators*/ true); +} + +//===----------------------------------------------------------------------===// +// CustomReverseRuleAugmentedPrimalOp +//===----------------------------------------------------------------------===// + +mlir::ParseResult +CustomReverseRuleAugmentedPrimalOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseAugmentedFn(parser, result); +} + +void CustomReverseRuleAugmentedPrimalOp::print(OpAsmPrinter &p) { + printAugmentedFn(p, getFunctionType(), getBody()); +} + +//===----------------------------------------------------------------------===// +// CustomReverseRuleReverseOp +//===----------------------------------------------------------------------===// + +mlir::ParseResult CustomReverseRuleReverseOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseAugmentedFn(parser, result); +} + +void CustomReverseRuleReverseOp::print(OpAsmPrinter &p) { + auto rule = cast(this->getParentOp()); + printAugmentedFn(p, getFunctionType(), getBody()); +} + //===----------------------------------------------------------------------===// // CallAugmentedPrimalOp //===----------------------------------------------------------------------===// -LogicalResult CallAugmentedPrimalOp::verifySymbolUses( - SymbolTableCollection &symbolTable) { +LogicalResult +CallAugmentedPrimalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { auto global = - symbolTable.lookupNearestSymbolFrom(*this, getFnAttr()); + symbolTable.lookupNearestSymbolFrom( + *this, getFnAttr()); if (!global) return emitOpError("'") << getFn() << "' does not reference a valid custom reverse rule"; @@ -579,10 +646,11 @@ LogicalResult CallAugmentedPrimalOp::verifySymbolUses( // CallCustomReverseOp //===----------------------------------------------------------------------===// -LogicalResult CallCustomReverseOp::verifySymbolUses( - SymbolTableCollection &symbolTable) { +LogicalResult +CallCustomReverseOp::verifySymbolUses(SymbolTableCollection &symbolTable) { auto global = - symbolTable.lookupNearestSymbolFrom(*this, getFnAttr()); + symbolTable.lookupNearestSymbolFrom( + *this, getFnAttr()); if (!global) return emitOpError("'") << getFn() << "' does not reference a valid custom reverse rule"; @@ -590,7 +658,6 @@ LogicalResult CallCustomReverseOp::verifySymbolUses( return success(); } - //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index 85deaeea0c4..b1151e0e58d 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -380,8 +380,13 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( OpBuilder ruleBuilder(ruleBody, ruleBody->begin()); + SmallVector revInputTypes, revOutputTypes, primalInputTypes, primalOutputTypes; + + auto revFuncType = FunctionType::get(fn.getContext(), revInputTypes, revOutputTypes); + auto primalFuncType = FunctionType::get(fn.getContext(), primalInputTypes, primalOutputTypes); + auto reverse = - ruleBuilder.create(fn.getLoc()); + ruleBuilder.create(fn.getLoc(), revFuncType); ruleBuilder.create(fn.getLoc(), ValueRange{}); ruleBuilder.setInsertionPoint(reverse); @@ -473,7 +478,7 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( ruleBuilder.setInsertionPoint(reverse); auto augmentedPrimal = ruleBuilder.create( - fn.getLoc()); + fn.getLoc(), primalFuncType); augmentedPrimal.getBody().takeBody(newFunc.getFunctionBody()); for (Block &b : augmentedPrimal.getBody()) { if (b.getNumSuccessors() == 0) { diff --git a/enzyme/deferred.mlir b/enzyme/deferred.mlir index 30c1c04cf11..a092dcd3b2e 100644 --- a/enzyme/deferred.mlir +++ b/enzyme/deferred.mlir @@ -31,8 +31,7 @@ module { %0 = "enzyme.init"() : () -> !enzyme.Cache %1 = "enzyme.init"() : () -> !enzyme.Cache - enzyme.custom_reverse_rule.augmented_primal { - ^bb0(%a: f32, %b: f32): + enzyme.custom_reverse_rule.augmented_primal (%a: f32, %b: f32) -> f32 { "enzyme.push"(%0, %a) : (!enzyme.Cache, f32) -> () "enzyme.push"(%1, %b) : (!enzyme.Cache, f32) -> () @@ -41,8 +40,7 @@ module { enzyme.yield %res : f32 } - enzyme.custom_reverse_rule.reverse { - ^bb0(%dres: f32): + enzyme.custom_reverse_rule.reverse (%dres: f32) -> (f32, f32) { %a = "enzyme.pop"(%0) : (!enzyme.Cache) -> f32 %b = "enzyme.pop"(%1) : (!enzyme.Cache) -> f32 @@ -74,14 +72,12 @@ module { enzyme.custom_reverse_rule @exp_f32 { %cache = "enzyme.init"() : () -> !enzyme.Cache - enzyme.custom_reverse_rule.augmented_primal { - ^bb0(%arg0: f32): + enzyme.custom_reverse_rule.augmented_primal (%arg0: f32) { %res = math.exp %arg0 : f32 "enzyme.push"(%cache, %res) : (!enzyme.Cache, f32) -> () enzyme.yield %res : f32 } - enzyme.custom_reverse_rule.reverse { - ^bb0(%dres: f32): + enzyme.custom_reverse_rule.reverse (%dres: f32) -> f32 { %res = "enzyme.pop"(%cache) : (!enzyme.Cache) -> (f32) %darg0 = arith.mulf %dres, %res : f32 enzyme.yield %darg0 : f32 From 5f8c9c1963de3717e1d8aa33ad51c20e40f7221a Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Mon, 3 Nov 2025 17:31:59 +0100 Subject: [PATCH 16/26] fmt --- .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index b1151e0e58d..3c1c4de5760 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -363,8 +363,8 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( auto customRuleName = name + "_reverse_rule"; SmallVector nameBuf; - auto ruleNameAttr = StringAttr::get( - fn.getContext(), customRuleName.toStringRef(nameBuf)); + auto ruleNameAttr = + StringAttr::get(fn.getContext(), customRuleName.toStringRef(nameBuf)); SmallVector argTys( cast(fn.getFunctionType()).getInputs().begin(), @@ -380,13 +380,16 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( OpBuilder ruleBuilder(ruleBody, ruleBody->begin()); - SmallVector revInputTypes, revOutputTypes, primalInputTypes, primalOutputTypes; + SmallVector revInputTypes, revOutputTypes, primalInputTypes, + primalOutputTypes; - auto revFuncType = FunctionType::get(fn.getContext(), revInputTypes, revOutputTypes); - auto primalFuncType = FunctionType::get(fn.getContext(), primalInputTypes, primalOutputTypes); + auto revFuncType = + FunctionType::get(fn.getContext(), revInputTypes, revOutputTypes); + auto primalFuncType = + FunctionType::get(fn.getContext(), primalInputTypes, primalOutputTypes); - auto reverse = - ruleBuilder.create(fn.getLoc(), revFuncType); + auto reverse = ruleBuilder.create( + fn.getLoc(), revFuncType); ruleBuilder.create(fn.getLoc(), ValueRange{}); ruleBuilder.setInsertionPoint(reverse); From 83c885c0ee7841e0d654b436e7ef06c7be930378 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Mon, 3 Nov 2025 17:40:34 +0100 Subject: [PATCH 17/26] Model side effects in LLVMExt --- enzyme/BUILD | 6 +++++- enzyme/Enzyme/MLIR/Dialect/LLVMExt/LLVMExtOps.td | 5 +++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/enzyme/BUILD b/enzyme/BUILD index 0302c51219c..d3455e995ca 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -253,6 +253,7 @@ td_library( includes = ["."], deps = [ "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfaces" ] ) @@ -319,7 +320,10 @@ gentbl_cc_library( ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "Enzyme/MLIR/Dialect/LLVMExt/LLVMExtOps.td", - deps = [":LLVMExtDialectTdFiles"], + deps = [ + ":LLVMExtDialectTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles" + ] ) gentbl_cc_library( diff --git a/enzyme/Enzyme/MLIR/Dialect/LLVMExt/LLVMExtOps.td b/enzyme/Enzyme/MLIR/Dialect/LLVMExt/LLVMExtOps.td index f1f8675f52c..0cf874f2aa1 100644 --- a/enzyme/Enzyme/MLIR/Dialect/LLVMExt/LLVMExtOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/LLVMExt/LLVMExtOps.td @@ -2,13 +2,14 @@ #define ENZYME_DIALECT_LLVMEXT_OPS_TD include "Dialect.td" +include "mlir/Interfaces/SideEffectInterfaces.td" def LLVMPtr : Type($_self)">>; def AllocOp : LLVMExtOp<"alloc", []> { let summary = "Allocates memory"; let arguments = (ins I64 : $size); - let results = (outs LLVMPtr : $result); + let results = (outs Res : $result); let assemblyFormat = [{ $size attr-dict `:` functional-type($size, results) @@ -17,7 +18,7 @@ def AllocOp : LLVMExtOp<"alloc", []> { def FreeOp : LLVMExtOp<"free", []> { let summary = "Frees memory"; - let arguments = (ins LLVMPtr : $ptr); + let arguments = (ins Arg : $ptr); let results = (outs); let assemblyFormat = [{ From a1b6d32f60a4deaee50d983b22bb089eaac389ee Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Tue, 4 Nov 2025 09:48:33 +0100 Subject: [PATCH 18/26] cmake --- enzyme/Enzyme/MLIR/Passes/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt index 43291dbbb28..85affd21463 100644 --- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms EnzymeWrapPass.cpp InlineEnzymeRegions.cpp LowerLLVMExtPass.cpp + LowerEnzymeCustomRulesToFuncPass.cpp PrintActivityAnalysis.cpp PrintAliasAnalysis.cpp EnzymeToMemRef.cpp From a8b40870387efe1f8eea3e2cfd22d9dc88a04bce Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Wed, 5 Nov 2025 09:47:21 +0100 Subject: [PATCH 19/26] Operation::create --- .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 19 +++++++++---------- .../MLIR/Passes/EnzymeBatchToTensorPass.cpp | 10 +++++----- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 9 +++++---- .../LowerEnzymeCustomRulesToFuncPass.cpp | 10 +++++----- 4 files changed, 24 insertions(+), 24 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index 93dca22baf7..62d4abe78bc 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -388,9 +388,9 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( auto primalFuncType = FunctionType::get(fn.getContext(), primalInputTypes, primalOutputTypes); - auto reverse = ruleBuilder.create( - fn.getLoc(), revFuncType); - ruleBuilder.create(fn.getLoc(), ValueRange{}); + auto reverse = enzyme::CustomReverseRuleReverseOp::create( + ruleBuilder, fn.getLoc(), revFuncType); + enzyme::YieldOp::create(ruleBuilder, fn.getLoc(), ValueRange{}); ruleBuilder.setInsertionPoint(reverse); @@ -431,13 +431,13 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( gutils->createReverseModeBlocks(fn.getFunctionBody(), reverse.getBody()); gutils->registerCacheCreatorHook([&](Type ty) -> std::pair { - Value cache = ruleBuilder.create(fn.getLoc(), ty); + Value cache = enzyme::InitOp::create(ruleBuilder, fn.getLoc(), ty); return {cache, cache}; }); gutils->registerGradientCreatorHook([&](Location loc, Type ty) -> Value { auto reverseEntry = &reverse.getBody().front(); OpBuilder gBuilder(reverseEntry, reverseEntry->begin()); - return gBuilder.create(loc, ty); + return enzyme::InitOp::create(gBuilder, loc, ty); }); bool valid = true; @@ -474,20 +474,19 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( toYield.push_back(gutils->diffe(arg, rBuilder)); } } - rBuilder.create(fn.getLoc(), toYield); + enzyme::YieldOp::create(rBuilder, fn.getLoc(), toYield); } } ruleBuilder.setInsertionPoint(reverse); - auto augmentedPrimal = - ruleBuilder.create( - fn.getLoc(), primalFuncType); + auto augmentedPrimal = enzyme::CustomReverseRuleAugmentedPrimalOp::create( + ruleBuilder, fn.getLoc(), primalFuncType); augmentedPrimal.getBody().takeBody(newFunc.getFunctionBody()); for (Block &b : augmentedPrimal.getBody()) { if (b.getNumSuccessors() == 0) { Operation *term = b.getTerminator(); OpBuilder builder(term); - builder.create(term->getLoc(), term->getOperands()); + enzyme::YieldOp::create(builder, term->getLoc(), term->getOperands()); term->erase(); } } diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeBatchToTensorPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeBatchToTensorPass.cpp index 576057b261f..7b8c5cf1335 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeBatchToTensorPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeBatchToTensorPass.cpp @@ -115,9 +115,9 @@ struct ConcatOpConversion : public OpConversionPattern { if (inRankTy.isDynamicDim(i)) { // extract dynamic dim Value dynIdx = - rewriter.create(op->getLoc(), i); + arith::ConstantIndexOp::create(rewriter, op->getLoc(), i); Value dynVal = - rewriter.create(op->getLoc(), in, dynIdx); + tensor::DimOp::create(rewriter, op->getLoc(), in, dynIdx); outDynamicDims.push_back(dynVal); } } @@ -126,9 +126,9 @@ struct ConcatOpConversion : public OpConversionPattern { auto outStaticDimAttr = rewriter.getDenseI64ArrayAttr(newInTy.getShape()); - Value newInput = rewriter.create( - op->getLoc(), newInTy, in, reassociationAttr, outDynamicDims, - outStaticDimAttr); + Value newInput = tensor::ExpandShapeOp::create( + rewriter, op->getLoc(), newInTy, in, reassociationAttr, + outDynamicDims, outStaticDimAttr); expandedInputs.push_back(newInput); } diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index 50fd71cc367..ee7be7602b7 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -483,8 +483,9 @@ struct DifferentiatePass CI.getStrongZero()); OpBuilder builder(CI); - auto primalCall = builder.create( - CI.getLoc(), CI->getResultTypes(), ruleToCall, CI.getOperands()); + auto primalCall = enzyme::CallAugmentedPrimalOp::create( + builder, CI.getLoc(), CI->getResultTypes(), ruleToCall, + CI.getOperands()); for (auto [oldRes, newRes] : llvm::zip_equal(CI->getResults(), primalCall.getResults())) { oldRes.replaceAllUsesWith(newRes); @@ -500,8 +501,8 @@ struct DifferentiatePass dyn_cast(tapeUser)) { OpBuilder builder(revCall); - auto newRevCall = builder.create( - revCall.getLoc(), revCall.getResultTypes(), ruleToCall, + auto newRevCall = enzyme::CallCustomReverseOp::create( + builder, revCall.getLoc(), revCall.getResultTypes(), ruleToCall, revCall.getInputs(), tape); revCall.replaceAllUsesWith(newRevCall.getResults()); diff --git a/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp b/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp index d582a62be9a..21611a2ccce 100644 --- a/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp @@ -187,7 +187,7 @@ lowerCustomReverseRuleToFunc(enzyme::CustomReverseRuleOp revRule) { toReturn.push_back(info.pushOp.getValue()); info.pushOp->erase(); } - builder.create(term->getLoc(), toReturn); + func::ReturnOp::create(builder, term->getLoc(), toReturn); term->erase(); } } @@ -206,7 +206,7 @@ lowerCustomReverseRuleToFunc(enzyme::CustomReverseRuleOp revRule) { Operation *term = b.getTerminator(); if (isa(term)) { OpBuilder builder(term); - builder.create(term->getLoc(), term->getOperands()); + func::ReturnOp::create(builder, term->getLoc(), term->getOperands()); term->erase(); } } @@ -237,8 +237,8 @@ lowerCustomReverseRuleToFunc(enzyme::CustomReverseRuleOp revRule) { continue; OpBuilder builder(CAP); - auto primalCall = builder.create(CAP.getLoc(), primalFunc, - CAP->getOperands()); + auto primalCall = func::CallOp::create(builder, CAP.getLoc(), primalFunc, + CAP->getOperands()); for (auto [oldRes, newRes] : llvm::zip(CAP.getOutputs(), primalCall->getResults())) { oldRes.replaceAllUsesWith(newRes); @@ -260,7 +260,7 @@ lowerCustomReverseRuleToFunc(enzyme::CustomReverseRuleOp revRule) { .slice(revRule.getFunctionType().getNumResults(), caches.size()) .end()); auto reverseCall = - builder.create(CCR.getLoc(), reverseFunc, operands); + func::CallOp::create(builder, CCR.getLoc(), reverseFunc, operands); for (auto [oldRes, newRes] : llvm::zip(CCR.getResults(), reverseCall.getResults())) { oldRes.replaceAllUsesWith(newRes); From eadc5ad152103a65a757bbfdf238e566e4351fd7 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Wed, 5 Nov 2025 11:21:55 +0100 Subject: [PATCH 20/26] Run mincut cache when lowering to func --- .../MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp | 11 +++++++++++ enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp b/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp index 21611a2ccce..55afea08d51 100644 --- a/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/LowerEnzymeCustomRulesToFuncPass.cpp @@ -15,6 +15,7 @@ #include "Passes/Passes.h" #include "Passes/RemovalUtils.h" +#include "RemovalUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -152,6 +153,16 @@ lowerCustomReverseRuleToFunc(enzyme::CustomReverseRuleOp revRule) { } } + if (singleBlock && !revRule->hasAttr("enzyme.disable_mincut")) { + Block *fwdBlock = &primal.getBody().front(), + *bwdBlock = &reverse.getBody().front(); + + IRMapping fwdrevmap; + PatternRewriter rewriter(primal.getContext()); + rewriter.setInsertionPointToStart(bwdBlock); + mlir::enzyme::minCutCache(fwdBlock, bwdBlock, caches, rewriter, fwdrevmap); + } + primalResultTypes.append(cacheTypes.begin(), cacheTypes.end()); reverseArgTypes.append(cacheTypes.begin(), cacheTypes.end()); diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp index e15d27332db..e925bf231c0 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp @@ -602,7 +602,7 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse, int64_t newSize = computeSizeOfType(candidate), newRank = computeRankOfType(candidate); - if (newRank < curRank || (newRank == curRank && newSize < curSize)) { + if (newRank <= curRank || (newRank == curRank && newSize <= curSize)) { newCaches.remove(cur); newCaches.insert(candidate); todo.push_back(candidate); From fe357e87950f33216436de60f1db3962ce72b59a Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Wed, 5 Nov 2025 11:40:07 +0100 Subject: [PATCH 21/26] Small mincut fix --- enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp index e925bf231c0..768d29c20db 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp @@ -452,7 +452,7 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse, if (v.getParentBlock() != reverse) { continue; } - if (G.contains(Node(v))) { + if (v.getDefiningOp() || G.contains(Node(v))) { continue; } Required.insert(op); From 6d4b54a68f1c188daf499a0afcf94cb9a26d2b99 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Wed, 5 Nov 2025 13:13:19 +0100 Subject: [PATCH 22/26] ok --- enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp index 768d29c20db..d96c565584c 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp @@ -399,6 +399,7 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse, for (auto user : poped.getUsers()) { if (user->getBlock() != reverse || !isMovable(user)) { G[info.pushedValue()].insert(Node(user)); + llvm::errs() << "adding required = " << user << "\n"; Required.insert(user); isRequired = true; break; @@ -452,7 +453,11 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse, if (v.getParentBlock() != reverse) { continue; } - if (v.getDefiningOp() || G.contains(Node(v))) { + if (v.getDefiningOp()) { + // Poped value would be part of the graph through the pushed value. + continue; + } + if (G.contains(Node(v))) { continue; } Required.insert(op); From 98506734a8cdff9bb38657686f3e685baa0d6772 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Wed, 5 Nov 2025 13:30:50 +0100 Subject: [PATCH 23/26] fmt --- enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp index d96c565584c..2c292b12a03 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp @@ -454,7 +454,7 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse, continue; } if (v.getDefiningOp()) { - // Poped value would be part of the graph through the pushed value. + // Poped value would be part of the graph through the pushed value. continue; } if (G.contains(Node(v))) { From c578206b7ae24b8e14c206b7db9ce75d0f8d09de Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Wed, 5 Nov 2025 13:31:05 +0100 Subject: [PATCH 24/26] remove unused variable --- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index 62d4abe78bc..5c2bffaccf7 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -442,7 +442,6 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( bool valid = true; for (auto &oBB : fn.getFunctionBody()) { - Block *newBB = gutils->getNewFromOriginal(&oBB); Block *reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(&oBB); if (oBB.getNumSuccessors() == 0) { Operation *oTerm = oBB.getTerminator(); From e06f4a47908c6cd893d74bade52d106445d895f4 Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Fri, 14 Nov 2025 03:22:12 -0600 Subject: [PATCH 25/26] Track tape usage through cache --- .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 20 ++++-- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 69 ++++++++++--------- 2 files changed, 52 insertions(+), 37 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index baefbc68622..47ef4c5e925 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -412,14 +412,26 @@ FlatSymbolRefAttr MEnzymeLogic::CreateSplitModeDiff( strongZero); gutils->createReverseModeBlocks(fn.getFunctionBody(), reverse.getBody()); - gutils->registerCacheCreatorHook([&](Type ty) -> std::pair { + std::function(Type)> cacheCreatorHook = + [&](Type ty) -> std::pair { Value cache = enzyme::InitOp::create(ruleBuilder, fn.getLoc(), ty); return {cache, cache}; - }); - gutils->registerGradientCreatorHook([&](Location loc, Type ty) -> Value { + }; + gutils->registerCacheCreatorHook(cacheCreatorHook); + std::function gradientCreatorHook = + [&](Location loc, Type ty) -> Value { + auto shadowType = + cast(ty).getShadowType(gutils->width); + auto gradientType = enzyme::GradientType::get(ty.getContext(), shadowType); auto reverseEntry = &reverse.getBody().front(); OpBuilder gBuilder(reverseEntry, reverseEntry->begin()); - return enzyme::InitOp::create(gBuilder, loc, ty); + return enzyme::InitOp::create(gBuilder, loc, gradientType); + }; + gutils->registerGradientCreatorHook(gradientCreatorHook); + + auto scope = llvm::make_scope_exit([&]() { + gutils->deregisterCacheCreatorHook(cacheCreatorHook); + gutils->deregisterGradientCreatorHook(gradientCreatorHook); }); bool valid = true; diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index 2da8d188e5f..7fbd3953b32 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -14,6 +14,7 @@ #include "Interfaces/GradientUtilsReverse.h" #include "PassDetails.h" #include "Passes/Passes.h" +#include "Passes/RemovalUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -220,8 +221,8 @@ struct DifferentiatePass { for (auto act : CI.getActivity()) { if (call_idx >= CI.getInputs().size()) { - llvm::errs() << "Too few arguments to autodiff op" - << " CI: " << CI << "\n"; + llvm::errs() << "Too few arguments to autodiff op" << " CI: " << CI + << "\n"; return failure(); } mlir::Value res = CI.getInputs()[call_idx]; @@ -256,8 +257,8 @@ struct DifferentiatePass args.push_back(res); if (ty == DIFFE_TYPE::DUP_ARG || ty == DIFFE_TYPE::DUP_NONEED) { if (call_idx >= CI.getInputs().size()) { - llvm::errs() << "Too few arguments to autodiff op" - << "CI: " << CI << "\n"; + llvm::errs() << "Too few arguments to autodiff op" << "CI: " << CI + << "\n"; return failure(); } res = CI.getInputs()[call_idx]; @@ -307,8 +308,8 @@ struct DifferentiatePass returnShadows.push_back(false); if (ty == DIFFE_TYPE::OUT_DIFF) { if (call_idx >= CI.getInputs().size()) { - llvm::errs() << "Too few arguments to autodiff op" - << "CI: " << CI << "\n"; + llvm::errs() << "Too few arguments to autodiff op" << "CI: " << CI + << "\n"; return failure(); } mlir::Value res = CI.getInputs()[call_idx]; @@ -356,16 +357,6 @@ struct DifferentiatePass enzyme::AutoDiffSplitModePrimalOp CI) { auto tape = CI.getTape(); - SmallVector reverseCalls; - for (auto user : tape.getUsers()) { - if (isa(user)) - reverseCalls.push_back(user); - else { - user->emitError() << "todo: unsupported tape usage"; - return failure(); - } - } - auto &symbTable = symbolTable.getSymbolTable(SymbolTable::getNearestSymbolTable(CI)); @@ -392,8 +383,8 @@ struct DifferentiatePass { for (auto act : CI.getActivity()) { if (call_idx >= CI.getInputs().size()) { - llvm::errs() << "Too few arguments to autodiff op" - << " CI: " << CI << "\n"; + llvm::errs() << "Too few arguments to autodiff op" << " CI: " << CI + << "\n"; return failure(); } mlir::Value res = CI.getInputs()[call_idx]; @@ -498,21 +489,33 @@ struct DifferentiatePass SetVector toDelete; tape = primalCall.getTape(); - for (auto tapeUser : tape.getUsers()) { - if (auto revCall = - dyn_cast(tapeUser)) { - - OpBuilder builder(revCall); - auto newRevCall = enzyme::CallCustomReverseOp::create( - builder, revCall.getLoc(), revCall.getResultTypes(), ruleToCall, - revCall.getInputs(), tape); - revCall.replaceAllUsesWith(newRevCall.getResults()); - - toDelete.insert(revCall); - } else { - tapeUser->emitError() - << "todo: support tape going through this operation"; - return failure(); + + SmallVector tapeWorklist = {tape}; + while (!tapeWorklist.empty()) { + tape = tapeWorklist.back(); + tapeWorklist.pop_back(); + for (auto tapeUser : tape.getUsers()) { + if (auto revCall = + dyn_cast(tapeUser)) { + + OpBuilder builder(revCall); + auto newRevCall = enzyme::CallCustomReverseOp::create( + builder, revCall.getLoc(), revCall.getResultTypes(), ruleToCall, + revCall.getInputs(), tape); + revCall.replaceAllUsesWith(newRevCall.getResults()); + + toDelete.insert(revCall); + } else if (auto pushOp = dyn_cast(tapeUser)) { + assert(pushOp.getValue() == tape); + + CacheInfo info(pushOp.getCache()); + + tapeWorklist.push_back(info.popOp.getResult()); + } else { + tapeUser->emitError() + << "todo: support tape going through this operation"; + return failure(); + } } } From 460f0539c3eb9a5a8e2b6418e7a9e8557448fbda Mon Sep 17 00:00:00 2001 From: Paul Berg <9824244+Pangoraw@users.noreply.github.com> Date: Fri, 14 Nov 2025 11:11:22 -0600 Subject: [PATCH 26/26] Make tensor mutable if element type is mutable --- .../Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index dae3a9bc69d..660fbaa2e82 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -138,7 +138,13 @@ class TensorTypeInterface return batchType(self, width); } - bool isMutable(Type self) const { return false; } + bool isMutable(Type self) const { + auto tenType = cast(self); + auto ET = tenType.getElementType(); + auto iface = cast(ET); + return iface.isMutable(); + } + LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, Value val) const { return failure();