Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
99bd843
split mode
Pangoraw Oct 15, 2025
f6a3ddc
wip: Split mode and custom reverse rules
Pangoraw Oct 29, 2025
8e34471
wip: lower custom rule to func.
Pangoraw Oct 29, 2025
307c91d
tape
Pangoraw Oct 29, 2025
0beea78
Math derivatives
Pangoraw Oct 30, 2025
d35def9
update
Pangoraw Oct 30, 2025
7fb5ad7
Merge remote-tracking branch 'upstream/main' into split-mode
Pangoraw Oct 31, 2025
75a1af7
tape type
Pangoraw Oct 31, 2025
9590a19
fmt
Pangoraw Oct 31, 2025
07dc696
Merge remote-tracking branch 'upstream/main' into split-mode
Pangoraw Nov 3, 2025
6a7c0cf
update
Pangoraw Nov 3, 2025
122f9cc
dup args
Pangoraw Nov 3, 2025
9d16ca3
Shadow type
Pangoraw Nov 3, 2025
29c0bdd
fmt
Pangoraw Nov 3, 2025
d78841f
Symbol Interface for custom rule
Pangoraw Nov 3, 2025
0cf00d6
Cleanup
Pangoraw Nov 3, 2025
be8f242
Each subop has its own function type and custom parser/printer
Pangoraw Nov 3, 2025
5f8c9c1
fmt
Pangoraw Nov 3, 2025
83c885c
Model side effects in LLVMExt
Pangoraw Nov 3, 2025
a1b6d32
cmake
Pangoraw Nov 4, 2025
e74f2dc
Merge remote-tracking branch 'upstream/main' into split-mode
Pangoraw Nov 5, 2025
a8b4087
Operation::create
Pangoraw Nov 5, 2025
eadc5ad
Run mincut cache when lowering to func
Pangoraw Nov 5, 2025
fe357e8
Small mincut fix
Pangoraw Nov 5, 2025
6d4b54a
ok
Pangoraw Nov 5, 2025
9850673
fmt
Pangoraw Nov 5, 2025
c578206
remove unused variable
Pangoraw Nov 5, 2025
81ebb44
Merge branch 'main' into split-mode
Pangoraw Nov 13, 2025
e06f4a4
Track tape usage through cache
Pangoraw Nov 14, 2025
460f053
Make tensor mutable if element type is mutable
Pangoraw Nov 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion enzyme/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ td_library(
includes = ["."],
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfaces"
]
)

Expand Down Expand Up @@ -319,7 +320,10 @@ gentbl_cc_library(
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Enzyme/MLIR/Dialect/LLVMExt/LLVMExtOps.td",
deps = [":LLVMExtDialectTdFiles"],
deps = [
":LLVMExtDialectTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles"
]
)

gentbl_cc_library(
Expand Down
95 changes: 93 additions & 2 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,88 @@ def AutoDiffOp : Enzyme_Op<"autodiff",
}];
}

def AnyTape : Type<CPred<"::llvm::isa<mlir::enzyme::TapeType>($_self)">, "enzyme tape">;

def AutoDiffSplitModePrimalOp : Enzyme_Op<"autodiff_split_mode.primal",
[DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Runs an augmented primal that can later be used to generate reverse with enzyme.autodiff_deferred_reverse";
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr<I64Attr, "1">:$width, DefaultValuedAttr<BoolAttr, "false">:$strong_zero);
let results = (outs Variadic<AnyType>:$outputs, AnyTape:$tape);

let assemblyFormat = [{
$fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
}];
}

def AutoDiffSplitModeReverseOp : Enzyme_Op<"autodiff_split_mode.reverse",
[DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Runs the reverse from an enzyme.autodiff_deferred_primal result";
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs, AnyTape:$tape,
ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity);
let results = (outs Variadic<AnyType>:$outputs);

let assemblyFormat = [{
$fn `(` operands `)` attr-dict `:` functional-type(operands, results)
}];
}

def CustomReverseRuleOp : Enzyme_Op<"custom_reverse_rule", [IsolatedFromAbove, Symbol]> {
let summary = "Parent operation for custom reverse rule declaration.";
let arguments = (ins SymbolNameAttr:$sym_name, TypeAttrOf<FunctionType>:$function_type, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity);
let regions = (region AnyRegion:$body);
let results = (outs);

let assemblyFormat = [{
$sym_name $body attr-dict-with-keyword
}];
}

def CustomReverseRuleAugmentedPrimalOp : Enzyme_Op<"custom_reverse_rule.augmented_primal", [
HasParent<"CustomReverseRuleOp">,
AutomaticAllocationScope,
AffineScope]> {
let summary = "Defines the augmented primal for a custom reverse rule";
let arguments = (ins TypeAttrOf<FunctionType>:$function_type);
let regions = (region AnyRegion:$body);
let results = (outs);

let hasCustomAssemblyFormat = 1;
}

def CustomReverseRuleReverseOp : Enzyme_Op<"custom_reverse_rule.reverse", [
HasParent<"CustomReverseRuleOp">,
AutomaticAllocationScope,
AffineScope]> {
let summary = "Defined the reverse for a custom rule.";
let arguments = (ins TypeAttrOf<FunctionType>:$function_type);
let regions = (region AnyRegion:$body);
let results = (outs);

let hasCustomAssemblyFormat = 1;
}

def CallAugmentedPrimalOp : Enzyme_Op<"call_augmented_primal", [
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "";
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs);
let results = (outs Variadic<AnyType>:$outputs, AnyTape:$tape);

let assemblyFormat = [{
$fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
}];
}

def CallCustomReverseOp : Enzyme_Op<"call_custom_reverse", [
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "";
let arguments = (ins FlatSymbolRefAttr:$fn, Variadic<AnyType>:$inputs, AnyTape:$tape);
let results = (outs Variadic<AnyType>:$outputs);

let assemblyFormat = [{
$fn `(` operands `)` attr-dict `:` functional-type(operands, results)
}];
}

def AutoDiffRegionOp : Enzyme_Op<"autodiff_region", [AutomaticAllocationScope]> {
let summary = "Perform reverse mode AD on a child region";
let arguments = (ins Variadic<AnyType>:$inputs, ActivityArrayAttr:$activity, ActivityArrayAttr:$ret_activity, DefaultValuedAttr<I64Attr, "1">:$width, DefaultValuedAttr<BoolAttr, "false">:$strong_zero, OptionalAttr<StrAttr>:$fn);
Expand Down Expand Up @@ -244,8 +326,8 @@ def AutoDiffRegionOp : Enzyme_Op<"autodiff_region", [AutomaticAllocationScope]>
}

def YieldOp : Enzyme_Op<"yield", [Pure, ReturnLike, Terminator,
ParentOneOf<["AutoDiffRegionOp", "LoopOp"]>]> {
let summary = "Yield values at the end of an autodiff_region or loop op";
ParentOneOf<["AutoDiffRegionOp", "LoopOp", "CustomReverseRuleReverseOp", "CustomReverseRuleAugmentedPrimalOp", "CustomReverseRuleOp"]>]> {
let summary = "Yield values at the end of an autodiff_region, loop op, reverse op, aug primal op or custom reverse rule op";
let arguments = (ins Variadic<AnyType>:$operands);
let assemblyFormat = [{
attr-dict ($operands^ `:` type($operands))?
Expand Down Expand Up @@ -324,6 +406,15 @@ def Cache : Enzyme_Type<"Cache"> {
let assemblyFormat = "`<` $type `>`";
}

def Tape : Enzyme_Type<"Tape"> {
let summary = "Tape for reverse deferred";
let description = [{
"Tape for reverse deferred"
}];
let parameters = (ins);
let mnemonic = "Tape";
}

def Gradient : Enzyme_Type<"Gradient"> {
let summary = "Mutable storage for accumulating gradients";
let description = [{
Expand Down
5 changes: 3 additions & 2 deletions enzyme/Enzyme/MLIR/Dialect/LLVMExt/LLVMExtOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
#define ENZYME_DIALECT_LLVMEXT_OPS_TD

include "Dialect.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

def LLVMPtr : Type<CPred<"::llvm::isa<::mlir::LLVM::LLVMPointerType>($_self)">>;

def AllocOp : LLVMExtOp<"alloc", []> {
let summary = "Allocates memory";
let arguments = (ins I64 : $size);
let results = (outs LLVMPtr : $result);
let results = (outs Res<LLVMPtr, "allocated ptr", [MemAlloc]> : $result);

let assemblyFormat = [{
$size attr-dict `:` functional-type($size, results)
Expand All @@ -17,7 +18,7 @@ def AllocOp : LLVMExtOp<"alloc", []> {

def FreeOp : LLVMExtOp<"free", []> {
let summary = "Frees memory";
let arguments = (ins LLVMPtr : $ptr);
let arguments = (ins Arg<LLVMPtr, "ptr to free", [MemFree]> : $ptr);
let results = (outs);

let assemblyFormat = [{
Expand Down
132 changes: 132 additions & 0 deletions enzyme/Enzyme/MLIR/Dialect/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -526,6 +527,137 @@ LogicalResult BatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}

//===----------------------------------------------------------------------===//
// AutoDiffSplitModePrimalOp
//===----------------------------------------------------------------------===//

LogicalResult AutoDiffSplitModePrimalOp::verifySymbolUses(
SymbolTableCollection &symbolTable) {
// TODO: Verify that the result type is same as the type of the referenced
// func.func op.
auto global =
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
if (!global)
return emitOpError("'")
<< getFn() << "' does not reference a valid global funcOp";

return success();
}

//===----------------------------------------------------------------------===//
// AutoDiffSplitModeReverseOp
//===----------------------------------------------------------------------===//

LogicalResult AutoDiffSplitModeReverseOp::verifySymbolUses(
SymbolTableCollection &symbolTable) {
// TODO: Verify that the result type is same as the type of the referenced
// func.func op.
auto global =
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
if (!global)
return emitOpError("'")
<< getFn() << "' does not reference a valid global funcOp";

return success();
}

static ParseResult parseAugmentedFn(OpAsmParser &parser,
OperationState &result) {
SmallVector<Type> argTys, resTys;
SmallVector<DictionaryAttr> resAttrs;

bool isVariadic = false;
SmallVector<OpAsmParser::Argument> arguments;
if (failed(function_interface_impl::parseFunctionSignatureWithArguments(
parser, /*allowVariadic*/ false, arguments, isVariadic, resTys,
resAttrs)))
return failure();

auto *body = result.addRegion();
if (failed(
parser.parseRegion(*body, arguments, /*enableNameShadowing*/ false)))
return failure();

result.addAttribute(
"function_type",
TypeAttr::get(FunctionType::get(result.getContext(), argTys, resTys)));

return success();
}

static void printAugmentedFn(OpAsmPrinter &p, FunctionType fnType,
Region &body) {
p << ' ';

call_interface_impl::printFunctionSignature(
p, fnType.getInputs(), nullptr, /*isVariadic*/ false, fnType.getResults(),
nullptr, &body, /*printEmptyResult*/ false);

p << ' ';
p.printRegion(body, /*printEntryBlockArgs*/ false,
/*printBlockTerminators*/ true);
}

//===----------------------------------------------------------------------===//
// CustomReverseRuleAugmentedPrimalOp
//===----------------------------------------------------------------------===//

mlir::ParseResult
CustomReverseRuleAugmentedPrimalOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseAugmentedFn(parser, result);
}

void CustomReverseRuleAugmentedPrimalOp::print(OpAsmPrinter &p) {
printAugmentedFn(p, getFunctionType(), getBody());
}

//===----------------------------------------------------------------------===//
// CustomReverseRuleReverseOp
//===----------------------------------------------------------------------===//

mlir::ParseResult CustomReverseRuleReverseOp::parse(OpAsmParser &parser,
OperationState &result) {
return parseAugmentedFn(parser, result);
}

void CustomReverseRuleReverseOp::print(OpAsmPrinter &p) {
auto rule = cast<CustomReverseRuleOp>(this->getParentOp());
printAugmentedFn(p, getFunctionType(), getBody());
}

//===----------------------------------------------------------------------===//
// CallAugmentedPrimalOp
//===----------------------------------------------------------------------===//

LogicalResult
CallAugmentedPrimalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto global =
symbolTable.lookupNearestSymbolFrom<enzyme::CustomReverseRuleOp>(
*this, getFnAttr());
if (!global)
return emitOpError("'")
<< getFn() << "' does not reference a valid custom reverse rule";

return success();
}

//===----------------------------------------------------------------------===//
// CallCustomReverseOp
//===----------------------------------------------------------------------===//

LogicalResult
CallCustomReverseOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto global =
symbolTable.lookupNearestSymbolFrom<enzyme::CustomReverseRuleOp>(
*this, getFnAttr());
if (!global)
return emitOpError("'")
<< getFn() << "' does not reference a valid custom reverse rule";

return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,13 @@ class TensorTypeInterface
return batchType(self, width);
}

bool isMutable(Type self) const { return false; }
bool isMutable(Type self) const {
auto tenType = cast<TensorType>(self);
auto ET = tenType.getElementType();
auto iface = cast<AutoDiffTypeInterface>(ET);
return iface.isMutable();
}

LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
Value val) const {
return failure();
Expand Down
5 changes: 5 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,14 @@ def CheckedDivF : SubRoutine<(Op $diffret, $x),
def LlvmCheckedMulF : LlvmInst<"FMulOp">;
def LlvmExpF : LlvmInst<"ExpOp">;

def ComplexCreate : ComplexInst<"CreateOp">;
def ComplexRe : ComplexInst<"ReOp">;
def ComplexIm : ComplexInst<"ImOp">;

def CosF : MathInst<"CosOp">;
def SinF : MathInst<"SinOp">;
def ExpF : MathInst<"ExpOp">;
def SqrtF : MathInst<"SqrtOp">;
def AbsF : MathInst<"AbsFOp">;

#endif // ENZYME_MLIR_IMPLEMENTATIONS_COMMON
30 changes: 30 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/ComplexDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,33 @@ def : MLIRDerivative<"complex", "MulOp", (Op $x, $y),
(CMul (DiffeRet), $x)
]
>;

def : MLIRDerivative<"complex", "ImOp", (Op $x),
[
(ComplexCreate
(TypeOf $x),
(ConstantFP<"0", "arith", "ConstantOp">),
(NegF (DiffeRet))
)
],
(ComplexIm (Shadow $x))
>;

def : MLIRDerivative<"complex", "ReOp", (Op $x),
[
(ComplexCreate
(TypeOf $x),
(DiffeRet),
(ConstantFP<"0", "arith", "ConstantOp">)
)
],
(ComplexRe (Shadow $x))
>;

def : MLIRDerivative<"complex", "CreateOp", (Op $re, $im),
[
(ComplexRe (DiffeRet)),
(ComplexIm (DiffeRet))
]
>;

Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "Interfaces/GradientUtils.h"
#include "Interfaces/GradientUtilsReverse.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"
Expand Down
Loading
Loading