Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 77 additions & 20 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -244,15 +244,15 @@ def AutoDiffRegionOp : Enzyme_Op<"autodiff_region", [AutomaticAllocationScope]>
}

def YieldOp : Enzyme_Op<"yield", [Pure, ReturnLike, Terminator,
ParentOneOf<["AutoDiffRegionOp", "LoopOp"]>]> {
let summary = "Yield values at the end of an autodiff_region or loop op";
ParentOneOf<["AutoDiffRegionOp", "ForLoopOp", "WhileLoopOp"]>]> {
let summary = "Yield values at the end of an autodiff_region or loop ops";
let arguments = (ins Variadic<AnyType>:$operands);
let assemblyFormat = [{
attr-dict ($operands^ `:` type($operands))?
}];
}

def LoopOp : Enzyme_Op<"loop", [AutomaticAllocationScope]> {
def ForLoopOp : Enzyme_Op<"for_loop", [AutomaticAllocationScope]> {
let summary = "Counted loop for probabilistic programming";
let description = [{
A counted loop operation that iterates from `lowerBound` to `upperBound`
Expand Down Expand Up @@ -549,7 +549,7 @@ def SimulateOp : Enzyme_Op<"simulate", [DeclareOpInterfaceMethods<SymbolUserOpIn
DefaultValuedStrAttr<StrAttr, "">:$name
);

let results = (outs Trace:$trace, AnyType:$weight, Variadic<AnyType>:$outputs);
let results = (outs Trace:$trace, AnyRankedTensor:$weight, Variadic<AnyType>:$outputs);

let assemblyFormat = [{
$fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
Expand All @@ -575,7 +575,7 @@ def GenerateOp : Enzyme_Op<"generate", [DeclareOpInterfaceMethods<SymbolUserOpIn
DefaultValuedStrAttr<StrAttr, "">:$name
);

let results = (outs Trace:$trace, AnyType:$weight, Variadic<AnyType>:$outputs);
let results = (outs Trace:$trace, AnyRankedTensor:$weight, Variadic<AnyType>:$outputs);

let assemblyFormat = [{
$fn `(` $inputs `)` `given` $constraint attr-dict `:` functional-type($inputs, results)
Expand Down Expand Up @@ -730,6 +730,21 @@ def RandomOp : Enzyme_Op<"random"> {
}];
}

def RandomSplitOp : Enzyme_Op<"randomSplit"> {
let summary = "Split RNG state into multiple independent states";
let description = [{
Splits an RNG state into multiple independent RNG states.
Reference: https://github.com/jax-ml/jax/blob/c25e095fcec9678a4ce5f723afce0c6a3c48a5e7/jax/_src/random.py#L281-L294
}];

let arguments = (ins AnyType:$rng_state);
let results = (outs Variadic<AnyType>:$output_rng_states);

let assemblyFormat = [{
$rng_state attr-dict `:` functional-type(operands, results)
}];
}

def GetSubtraceOp : Enzyme_Op<"getSubtrace", [Pure]> {
let summary = "Get a subtrace from a trace for a given symbol";
let description = [{
Expand Down Expand Up @@ -765,7 +780,7 @@ def GetWeightFromTraceOp : Enzyme_Op<"getWeightFromTrace", [Pure]> {
}];
let arguments = (ins Trace:$trace);

let results = (outs AnyType:$weight);
let results = (outs AnyRankedTensor:$weight);

let assemblyFormat = [{
$trace attr-dict `:` type($weight)
Expand Down Expand Up @@ -818,12 +833,12 @@ def UpdateOp : Enzyme_Op<"update", [DeclareOpInterfaceMethods<SymbolUserOpInterf
FlatSymbolRefAttr:$fn,
Variadic<AnyType>:$inputs,
Trace:$original_trace,
AnyType:$position,
AnyRankedTensor:$position,
AddressArrayAttr:$selection,
DefaultValuedStrAttr<StrAttr, "">:$name
);

let results = (outs Trace:$updated_trace, AnyType:$weight, AnyType:$output_rng_state);
let results = (outs Trace:$updated_trace, AnyRankedTensor:$weight, AnyType:$output_rng_state);

let assemblyFormat = [{
$fn `(` $inputs `)` `given` $original_trace `at` $position attr-dict `:` functional-type(operands, results)
Expand All @@ -847,7 +862,7 @@ def RegenerateOp : Enzyme_Op<"regenerate", [DeclareOpInterfaceMethods<SymbolUser
DefaultValuedStrAttr<StrAttr, "">:$name
);

let results = (outs Trace:$trace, AnyType:$weight, AnyType:$output_rng_state);
let results = (outs Trace:$trace, AnyRankedTensor:$weight, AnyType:$output_rng_state);

let assemblyFormat = [{
$fn `(` $inputs `)` `given` $original_trace attr-dict `:` functional-type($inputs, results)
Expand All @@ -872,7 +887,7 @@ def MHOp : Enzyme_Op<"mh", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
DefaultValuedStrAttr<StrAttr, "">:$name
);

let results = (outs Trace:$new_trace, AnyType:$accepted, AnyType:$output_rng_state);
let results = (outs Trace:$new_trace, AnyRankedTensor:$accepted, AnyType:$output_rng_state);

let assemblyFormat = [{
$fn `(` $inputs `)` `given` $original_trace attr-dict `:` functional-type($inputs, results)
Expand All @@ -889,7 +904,7 @@ def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods<SymbolUserOpInterface>
and the 0th operand in results is the updated RNG state.

Optional HMC-specific parameters:
- mass: Mass matrix (identity assumed if not provided)
- inverse_mass_matrix: Inverse mass matrix (identity assumed if not provided).
- step_size: Leapfrong integration step size
- num_steps: Number of leapfrog steps
- initial_momentum: deterministic initial momentum (debug)
Expand All @@ -901,18 +916,18 @@ def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods<SymbolUserOpInterface>
Variadic<AnyType>:$inputs,
Trace:$original_trace,
AddressArrayAttr:$selection,
Optional<AnyType>:$mass,
Optional<AnyType>:$step_size,
Optional<AnyType>:$num_steps,
Optional<AnyType>:$initial_momentum,
Optional<AnyRankedTensor>:$inverse_mass_matrix,
Optional<AnyRankedTensor>:$step_size,
Optional<AnyRankedTensor>:$num_steps,
Optional<AnyRankedTensor>:$initial_momentum,
DefaultValuedStrAttr<StrAttr, "">:$name
);

let results = (outs Trace:$new_trace, AnyType:$accepted, AnyType:$output_rng_state);
let results = (outs Trace:$new_trace, AnyRankedTensor:$accepted, AnyType:$output_rng_state);

let assemblyFormat = [{
`algorithm` `=` $alg $fn `(` $inputs `)` `given` $original_trace
(`mass` `=` $mass^ `:` type($mass))?
(`inverse_mass_matrix` `=` $inverse_mass_matrix^ `:` type($inverse_mass_matrix))?
(`step_size` `=` $step_size^ `:` type($step_size))?
(`num_steps` `=` $num_steps^ `:` type($num_steps))?
(`initial_momentum` `=` $initial_momentum^ `:` type($initial_momentum))?
Expand All @@ -921,12 +936,19 @@ def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods<SymbolUserOpInterface>
}

def DotOp : Enzyme_Op<"dot", [Pure]> {
let summary = "Compute dot product of two vectors";
let summary = "Computes a general dot product operation";
let description = [{
Computes the dot product of two 1D tensors (vectors).
Computes a general dot product operation. To be lowered to `stablehlo.dot_general`.
}];

let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs);
let arguments = (ins
AnyRankedTensor:$lhs,
AnyRankedTensor:$rhs,
DenseI64ArrayAttr:$lhs_batching_dimensions,
DenseI64ArrayAttr:$rhs_batching_dimensions,
DenseI64ArrayAttr:$lhs_contracting_dimensions,
DenseI64ArrayAttr:$rhs_contracting_dimensions
);
let results = (outs AnyRankedTensor:$result);

let assemblyFormat = [{
Expand Down Expand Up @@ -986,4 +1008,39 @@ def DumpOp : Enzyme_Op<"dump"> {
}];
}

def WhileLoopOp : Enzyme_Op<"while_loop", [AutomaticAllocationScope]> {
let summary = "While loop with condition";
let description = [{
A while loop operation that continues iterating as long as the condition
evaluates to true. Intended to be lowered to `stablehlo.while`.
}];

let arguments = (ins Variadic<AnyType>:$initArgs);
let regions = (region SizedRegion<1>:$conditionRegion,
SizedRegion<1>:$bodyRegion);
let results = (outs Variadic<AnyType>:$results);

let assemblyFormat = [{
`(` $initArgs `:` type($initArgs) `)`
`->` type(results)
`condition` $conditionRegion
`body` $bodyRegion
attr-dict
}];
}

def LogAddExpOp : Enzyme_Op<"log_add_exp", [Pure]> {
let summary = "Computes log(exp(x) + exp(y))";
let description = [{
Computes log(exp(x) + exp(y)).
}];

let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs);
let results = (outs AnyRankedTensor:$result);

let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` functional-type(operands, results)
}];
}

#endif // ENZYME_OPS
6 changes: 3 additions & 3 deletions enzyme/Enzyme/MLIR/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def ProbProgPass : Pass<"probprog"> {
/*description=*/"Optimization passes to apply to generated probabilistic programs"
>,
Option<
/*C++ variable name=*/"debugMCMC",
/*CLI argument=*/"debug-mcmc",
/*C++ variable name=*/"debugDump",
/*CLI argument=*/"debug-dump",
/*type=*/"bool",
/*default=*/"false",
/*description=*/"Enable debug prints for MCMC algorithms"
/*description=*/"Enable debug dump"
>,
];
}
Expand Down
Loading
Loading