diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index 5e197c21128..80ddbb7cfd7 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -244,15 +244,15 @@ def AutoDiffRegionOp : Enzyme_Op<"autodiff_region", [AutomaticAllocationScope]> } def YieldOp : Enzyme_Op<"yield", [Pure, ReturnLike, Terminator, - ParentOneOf<["AutoDiffRegionOp", "LoopOp"]>]> { - let summary = "Yield values at the end of an autodiff_region or loop op"; + ParentOneOf<["AutoDiffRegionOp", "ForLoopOp", "WhileLoopOp"]>]> { + let summary = "Yield values at the end of an autodiff_region or loop ops"; let arguments = (ins Variadic:$operands); let assemblyFormat = [{ attr-dict ($operands^ `:` type($operands))? }]; } -def LoopOp : Enzyme_Op<"loop", [AutomaticAllocationScope]> { +def ForLoopOp : Enzyme_Op<"for_loop", [AutomaticAllocationScope]> { let summary = "Counted loop for probabilistic programming"; let description = [{ A counted loop operation that iterates from `lowerBound` to `upperBound` @@ -549,7 +549,7 @@ def SimulateOp : Enzyme_Op<"simulate", [DeclareOpInterfaceMethods:$name ); - let results = (outs Trace:$trace, AnyType:$weight, Variadic:$outputs); + let results = (outs Trace:$trace, AnyRankedTensor:$weight, Variadic:$outputs); let assemblyFormat = [{ $fn `(` $inputs `)` attr-dict `:` functional-type($inputs, results) @@ -575,7 +575,7 @@ def GenerateOp : Enzyme_Op<"generate", [DeclareOpInterfaceMethods:$name ); - let results = (outs Trace:$trace, AnyType:$weight, Variadic:$outputs); + let results = (outs Trace:$trace, AnyRankedTensor:$weight, Variadic:$outputs); let assemblyFormat = [{ $fn `(` $inputs `)` `given` $constraint attr-dict `:` functional-type($inputs, results) @@ -730,6 +730,21 @@ def RandomOp : Enzyme_Op<"random"> { }]; } +def RandomSplitOp : Enzyme_Op<"randomSplit"> { + let summary = "Split RNG state into multiple independent states"; + let description = [{ + Splits an RNG state into multiple independent RNG states. + Reference: https://github.com/jax-ml/jax/blob/c25e095fcec9678a4ce5f723afce0c6a3c48a5e7/jax/_src/random.py#L281-L294 + }]; + + let arguments = (ins AnyType:$rng_state); + let results = (outs Variadic:$output_rng_states); + + let assemblyFormat = [{ + $rng_state attr-dict `:` functional-type(operands, results) + }]; +} + def GetSubtraceOp : Enzyme_Op<"getSubtrace", [Pure]> { let summary = "Get a subtrace from a trace for a given symbol"; let description = [{ @@ -765,7 +780,7 @@ def GetWeightFromTraceOp : Enzyme_Op<"getWeightFromTrace", [Pure]> { }]; let arguments = (ins Trace:$trace); - let results = (outs AnyType:$weight); + let results = (outs AnyRankedTensor:$weight); let assemblyFormat = [{ $trace attr-dict `:` type($weight) @@ -818,12 +833,12 @@ def UpdateOp : Enzyme_Op<"update", [DeclareOpInterfaceMethods:$inputs, Trace:$original_trace, - AnyType:$position, + AnyRankedTensor:$position, AddressArrayAttr:$selection, DefaultValuedStrAttr:$name ); - let results = (outs Trace:$updated_trace, AnyType:$weight, AnyType:$output_rng_state); + let results = (outs Trace:$updated_trace, AnyRankedTensor:$weight, AnyType:$output_rng_state); let assemblyFormat = [{ $fn `(` $inputs `)` `given` $original_trace `at` $position attr-dict `:` functional-type(operands, results) @@ -847,7 +862,7 @@ def RegenerateOp : Enzyme_Op<"regenerate", [DeclareOpInterfaceMethods:$name ); - let results = (outs Trace:$trace, AnyType:$weight, AnyType:$output_rng_state); + let results = (outs Trace:$trace, AnyRankedTensor:$weight, AnyType:$output_rng_state); let assemblyFormat = [{ $fn `(` $inputs `)` `given` $original_trace attr-dict `:` functional-type($inputs, results) @@ -872,7 +887,7 @@ def MHOp : Enzyme_Op<"mh", [DeclareOpInterfaceMethods]> { DefaultValuedStrAttr:$name ); - let results = (outs Trace:$new_trace, AnyType:$accepted, AnyType:$output_rng_state); + let results = (outs Trace:$new_trace, AnyRankedTensor:$accepted, AnyType:$output_rng_state); let assemblyFormat = [{ $fn `(` $inputs `)` `given` $original_trace attr-dict `:` functional-type($inputs, results) @@ -889,7 +904,7 @@ def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods and the 0th operand in results is the updated RNG state. Optional HMC-specific parameters: - - mass: Mass matrix (identity assumed if not provided) + - inverse_mass_matrix: Inverse mass matrix (identity assumed if not provided). - step_size: Leapfrong integration step size - num_steps: Number of leapfrog steps - initial_momentum: deterministic initial momentum (debug) @@ -901,18 +916,18 @@ def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods Variadic:$inputs, Trace:$original_trace, AddressArrayAttr:$selection, - Optional:$mass, - Optional:$step_size, - Optional:$num_steps, - Optional:$initial_momentum, + Optional:$inverse_mass_matrix, + Optional:$step_size, + Optional:$num_steps, + Optional:$initial_momentum, DefaultValuedStrAttr:$name ); - let results = (outs Trace:$new_trace, AnyType:$accepted, AnyType:$output_rng_state); + let results = (outs Trace:$new_trace, AnyRankedTensor:$accepted, AnyType:$output_rng_state); let assemblyFormat = [{ `algorithm` `=` $alg $fn `(` $inputs `)` `given` $original_trace - (`mass` `=` $mass^ `:` type($mass))? + (`inverse_mass_matrix` `=` $inverse_mass_matrix^ `:` type($inverse_mass_matrix))? (`step_size` `=` $step_size^ `:` type($step_size))? (`num_steps` `=` $num_steps^ `:` type($num_steps))? (`initial_momentum` `=` $initial_momentum^ `:` type($initial_momentum))? @@ -921,12 +936,19 @@ def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods } def DotOp : Enzyme_Op<"dot", [Pure]> { - let summary = "Compute dot product of two vectors"; + let summary = "Computes a general dot product operation"; let description = [{ - Computes the dot product of two 1D tensors (vectors). + Computes a general dot product operation. To be lowered to `stablehlo.dot_general`. }]; - let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs); + let arguments = (ins + AnyRankedTensor:$lhs, + AnyRankedTensor:$rhs, + DenseI64ArrayAttr:$lhs_batching_dimensions, + DenseI64ArrayAttr:$rhs_batching_dimensions, + DenseI64ArrayAttr:$lhs_contracting_dimensions, + DenseI64ArrayAttr:$rhs_contracting_dimensions + ); let results = (outs AnyRankedTensor:$result); let assemblyFormat = [{ @@ -986,4 +1008,39 @@ def DumpOp : Enzyme_Op<"dump"> { }]; } +def WhileLoopOp : Enzyme_Op<"while_loop", [AutomaticAllocationScope]> { + let summary = "While loop with condition"; + let description = [{ + A while loop operation that continues iterating as long as the condition + evaluates to true. Intended to be lowered to `stablehlo.while`. + }]; + + let arguments = (ins Variadic:$initArgs); + let regions = (region SizedRegion<1>:$conditionRegion, + SizedRegion<1>:$bodyRegion); + let results = (outs Variadic:$results); + + let assemblyFormat = [{ + `(` $initArgs `:` type($initArgs) `)` + `->` type(results) + `condition` $conditionRegion + `body` $bodyRegion + attr-dict + }]; +} + +def LogAddExpOp : Enzyme_Op<"log_add_exp", [Pure]> { + let summary = "Computes log(exp(x) + exp(y))"; + let description = [{ + Computes log(exp(x) + exp(y)). + }]; + + let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs); + let results = (outs AnyRankedTensor:$result); + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` functional-type(operands, results) + }]; +} + #endif // ENZYME_OPS diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index 5c7043d00d7..1f6e9ec1504 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -56,11 +56,11 @@ def ProbProgPass : Pass<"probprog"> { /*description=*/"Optimization passes to apply to generated probabilistic programs" >, Option< - /*C++ variable name=*/"debugMCMC", - /*CLI argument=*/"debug-mcmc", + /*C++ variable name=*/"debugDump", + /*CLI argument=*/"debug-dump", /*type=*/"bool", /*default=*/"false", - /*description=*/"Enable debug prints for MCMC algorithms" + /*description=*/"Enable debug dump" >, ]; } diff --git a/enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp index b5c249b4829..558ea600760 100644 --- a/enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp @@ -39,6 +39,37 @@ namespace enzyme { namespace { +static Value createIdentityMatrix(OpBuilder &builder, Location loc, + RankedTensorType matrixType) { + auto shape = matrixType.getShape(); + assert(shape.size() == 2 && shape[0] == shape[1] && + "Identity matrix must be square"); + int64_t n = shape[0]; + + SmallVector identityData(n * n, 0.0); + for (int64_t i = 0; i < n; ++i) { + identityData[i * n + i] = 1.0; + } + + return builder.create( + loc, matrixType, + DenseElementsAttr::get(matrixType, ArrayRef(identityData))); +} + +static Value createSigmoid(OpBuilder &builder, Location loc, Value x) { + auto xType = cast(x.getType()); + auto elemType = xType.getElementType(); + + auto oneConst = builder.create( + loc, xType, + DenseElementsAttr::get(xType, builder.getFloatAttr(elemType, 1.0))); + auto negX = builder.create(loc, x); + auto expNegX = builder.create(loc, negX); + auto onePlusExp = builder.create(loc, oneConst, expNegX); + auto result = builder.create(loc, oneConst, onePlusExp); + return result; +} + static bool computePositionSizeForAddress(Operation *op, FunctionOpInterface func, ArrayRef address, @@ -399,11 +430,11 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { }; struct LowerMCMCPattern : public mlir::OpRewritePattern { - bool debug; + bool debugDump; - LowerMCMCPattern(MLIRContext *context, bool debug, + LowerMCMCPattern(MLIRContext *context, bool debugDump, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), debug(debug) {} + : OpRewritePattern(context, benefit), debugDump(debugDump) {} LogicalResult matchAndRewrite(enzyme::MCMCOp mcmcOp, PatternRewriter &rewriter) const override { @@ -413,18 +444,68 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { case enzyme::MCMCAlgorithm::HMC: return lowerHMC(mcmcOp, rewriter); case enzyme::MCMCAlgorithm::NUTS: - mcmcOp.emitError("NUTS lowering not yet implemented"); - return failure(); + return lowerNUTS(mcmcOp, rewriter); default: - mcmcOp.emitError("Unknown MCMC algorithm"); + mcmcOp.emitError("ProbProg: Unknown MCMC algorithm"); return failure(); } } private: + // Reference: + // https://github.com/pyro-ppl/numpyro/blob/d49f71825691b554fb8188f8779dc3a5d13e7b96/numpyro/infer/hmc_util.py#L36 + struct NUTSTree { + Value q_left, p_left, grad_left; + Value q_right, p_right, grad_right; + Value q_proposal, grad_proposal, U_proposal, H_proposal; + Value depth, weight, turning, diverging; + Value sum_accept_probs, num_proposals, p_sum; + + static constexpr size_t NUM_FIELDS = 17; + + SmallVector toValues() const { + return {q_left, p_left, grad_left, + q_right, p_right, grad_right, + q_proposal, grad_proposal, U_proposal, + H_proposal, depth, weight, + turning, diverging, sum_accept_probs, + num_proposals, p_sum}; + } + + static NUTSTree fromValues(ArrayRef values) { + assert(values.size() == NUM_FIELDS); + NUTSTree tree; + tree.q_left = values[0]; + tree.p_left = values[1]; + tree.grad_left = values[2]; + tree.q_right = values[3]; + tree.p_right = values[4]; + tree.grad_right = values[5]; + tree.q_proposal = values[6]; + tree.grad_proposal = values[7]; + tree.U_proposal = values[8]; + tree.H_proposal = values[9]; + tree.depth = values[10]; + tree.weight = values[11]; + tree.turning = values[12]; + tree.diverging = values[13]; + tree.sum_accept_probs = values[14]; + tree.num_proposals = values[15]; + tree.p_sum = values[16]; + return tree; + } + + SmallVector getTypes() const { + SmallVector types; + for (auto val : toValues()) + types.push_back(val.getType()); + return types; + } + }; + Value conditionalDump(OpBuilder &builder, Location loc, Value value, StringRef label) const { - if (debug) { + if (debugDump) { return enzyme::DumpOp::create(builder, loc, value.getType(), value, builder.getStringAttr(label)) .getOutput(); @@ -455,7 +536,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { return failure(); } - Value mass = mcmcOp.getMass(); + Value invMass = mcmcOp.getInverseMassMatrix(); Value stepSize = mcmcOp.getStepSize(); Value numSteps = mcmcOp.getNumSteps(); @@ -508,15 +589,49 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { p0 = initialMomentum; rng1 = rngState; } else { - if (mass) { - auto randomOp = enzyme::RandomOp::create( - rewriter, loc, TypeRange{rngState.getType(), positionType}, - rngState, zeroConst, mass, - enzyme::RngDistributionAttr::get( - rewriter.getContext(), enzyme::RngDistribution::MULTINORMAL)); - rng1 = randomOp.getOutputRngState(); - p0 = randomOp.getResult(); + if (invMass) { + auto invMassType = cast(invMass.getType()); + if (invMassType.getRank() == 1) { + // Diagonal case: p0 = 1 / sqrt(invMass) * eps where eps ~ N(0, I) + auto sqrtInvMass = math::SqrtOp::create(rewriter, loc, invMass); + auto onesVector = arith::ConstantOp::create( + rewriter, loc, invMassType, + DenseElementsAttr::get(invMassType, + rewriter.getF64FloatAttr(1.0))); + auto massMatrixSqrt = + arith::DivFOp::create(rewriter, loc, onesVector, sqrtInvMass); + auto randomOp = enzyme::RandomOp::create( + rewriter, loc, TypeRange{rngState.getType(), positionType}, + rngState, zeroConst, oneConst, + enzyme::RngDistributionAttr::get( + rewriter.getContext(), enzyme::RngDistribution::NORMAL)); + rng1 = randomOp.getOutputRngState(); + auto eps = randomOp.getResult(); + p0 = arith::MulFOp::create(rewriter, loc, massMatrixSqrt, eps); + } else { + // Dense case: p0 = mass_matrix_sqrt @ eps where eps ~ N(0, I) + auto identityMatrix = + createIdentityMatrix(rewriter, loc, invMassType); + auto massMatrixSqrt = enzyme::CholeskySolveOp::create( + rewriter, loc, invMassType, invMass, identityMatrix); + auto randomOp = enzyme::RandomOp::create( + rewriter, loc, TypeRange{rngState.getType(), positionType}, + rngState, zeroConst, oneConst, + enzyme::RngDistributionAttr::get( + rewriter.getContext(), enzyme::RngDistribution::NORMAL)); + rng1 = randomOp.getOutputRngState(); + auto eps = randomOp.getResult(); + p0 = enzyme::DotOp::create( + rewriter, loc, positionType, massMatrixSqrt, eps, + /*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*lhs_contracting_dimensions=*/ + rewriter.getDenseI64ArrayAttr({1}), + /*rhs_contracting_dimensions=*/ + rewriter.getDenseI64ArrayAttr({0})); + } } else { + // Assume identity mass matrix: p0 ~ N(0, I) auto randomOp = enzyme::RandomOp::create( rewriter, loc, TypeRange{rngState.getType(), positionType}, rngState, zeroConst, oneConst, @@ -531,19 +646,41 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { rewriter, loc, tensorType, DenseElementsAttr::get(tensorType, rewriter.getF64FloatAttr(0.5))); - // 4. Compute initial kinetic energy K0 = 0.5 * p^T * M^{-1} * p + // 4. Compute initial kinetic energy K0 = 0.5 * p^T * M^-1 * p Value K0; - if (mass) { - auto MInvP0 = enzyme::CholeskySolveOp::create(rewriter, loc, - positionType, mass, p0); - auto p0DotMInvP = - enzyme::DotOp::create(rewriter, loc, tensorType, p0, MInvP0); + if (invMass) { + auto invMassType = cast(invMass.getType()); + Value invMassP0; + + if (invMassType.getRank() == 1) { + invMassP0 = arith::MulFOp::create(rewriter, loc, invMass, p0); + } else { + invMassP0 = enzyme::DotOp::create( + rewriter, loc, positionType, invMass, p0, + /*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({1}), + /*rhs_contracting_dimensions=*/ + rewriter.getDenseI64ArrayAttr({0})); + } + + auto p0DotInvMassP0 = enzyme::DotOp::create( + rewriter, loc, tensorType, p0, invMassP0, + /*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0}), + /*rhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0})); K0 = conditionalDump( rewriter, loc, - arith::MulFOp::create(rewriter, loc, halfConst, p0DotMInvP), + arith::MulFOp::create(rewriter, loc, halfConst, p0DotInvMassP0), "HMC: initial kinetic energy K0"); } else { - auto p0DotP0 = enzyme::DotOp::create(rewriter, loc, tensorType, p0, p0); + auto p0DotP0 = enzyme::DotOp::create( + rewriter, loc, tensorType, p0, p0, + /*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0}), + /*rhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0})); K0 = conditionalDump( rewriter, loc, arith::MulFOp::create(rewriter, loc, halfConst, p0DotP0), @@ -573,7 +710,6 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { Block *autodiffInitBlock = rewriter.createBlock(&autodiffInit.getBody()); autodiffInitBlock->addArgument(positionType, loc); - OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(autodiffInitBlock); Value q0Arg = autodiffInitBlock->getArgument(0); @@ -620,11 +756,11 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { SmallVector loopResultTypes = {positionType, positionType, positionType, rng0_final.getType()}; - auto loopOp = - enzyme::LoopOp::create(rewriter, loc, loopResultTypes, c0, numSteps, - c1, ValueRange{q0, p0, grad0, rng0_final}); + auto forLoopOp = enzyme::ForLoopOp::create( + rewriter, loc, loopResultTypes, c0, numSteps, c1, + ValueRange{q0, p0, grad0, rng0_final}); - Block *loopBody = rewriter.createBlock(&loopOp.getRegion()); + Block *loopBody = rewriter.createBlock(&forLoopOp.getRegion()); loopBody->addArgument(i64TensorType, loc); // iv loopBody->addArgument(positionType, loc); // q loopBody->addArgument(positionType, loc); // p @@ -647,11 +783,22 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { rewriter, loc, arith::SubFOp::create(rewriter, loc, p, deltaP1), "Leapfrog: momentum p(t + eps/2)"); - // 6.2 Full step on position: q += eps * M^{-1} * p1 + // 6.2 Full step on position: q += eps * M^-1 * p1 Value v1; - if (mass) { - v1 = enzyme::CholeskySolveOp::create(rewriter, loc, positionType, mass, - p1); + if (invMass) { + auto invMassType = cast(invMass.getType()); + + if (invMassType.getRank() == 1) { + v1 = arith::MulFOp::create(rewriter, loc, invMass, p1); + } else { + v1 = enzyme::DotOp::create( + rewriter, loc, positionType, invMass, p1, + /*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({1}), + /*rhs_contracting_dimensions=*/ + rewriter.getDenseI64ArrayAttr({0})); + } } else { v1 = p1; } @@ -715,10 +862,10 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { enzyme::YieldOp::create(rewriter, loc, ValueRange{q1, p2, newGradient, newRng}); - rewriter.setInsertionPointAfter(loopOp); - Value qL = loopOp.getResult(0); - Value pL = loopOp.getResult(1); - Value rngAfterLeapfrog = loopOp.getResult(3); + rewriter.setInsertionPointAfter(forLoopOp); + Value qL = forLoopOp.getResult(0); + Value pL = forLoopOp.getResult(1); + Value rngAfterLeapfrog = forLoopOp.getResult(3); // 7. Generate final trace with final position qL SmallVector finalUpdateInputs; @@ -738,19 +885,41 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { rewriter, loc, arith::NegFOp::create(rewriter, loc, weight1), "HMC: final potential energy U1"); - // K1 = 0.5 * pL^T * M^{-1} * pL + // K1 = 0.5 * pL^T * M^-1 * pL Value K1; - if (mass) { - auto MInvPL = enzyme::CholeskySolveOp::create(rewriter, loc, - positionType, mass, pL); - auto pLDotMInvPL = - enzyme::DotOp::create(rewriter, loc, tensorType, pL, MInvPL); + if (invMass) { + auto invMassType = cast(invMass.getType()); + Value invMassPL; + + if (invMassType.getRank() == 1) { + invMassPL = arith::MulFOp::create(rewriter, loc, invMass, pL); + } else { + invMassPL = enzyme::DotOp::create( + rewriter, loc, positionType, invMass, pL, + /*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({1}), + /*rhs_contracting_dimensions=*/ + rewriter.getDenseI64ArrayAttr({0})); + } + + auto pLDotInvMassPL = enzyme::DotOp::create( + rewriter, loc, tensorType, pL, invMassPL, + /*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0}), + /*rhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0})); K1 = conditionalDump( rewriter, loc, - arith::MulFOp::create(rewriter, loc, halfConst, pLDotMInvPL), + arith::MulFOp::create(rewriter, loc, halfConst, pLDotInvMassPL), "HMC: final kinetic energy K1"); } else { - auto pLDotPL = enzyme::DotOp::create(rewriter, loc, tensorType, pL, pL); + auto pLDotPL = enzyme::DotOp::create( + rewriter, loc, tensorType, pL, pL, + /*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0}), + /*rhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0})); K1 = conditionalDump( rewriter, loc, arith::MulFOp::create(rewriter, loc, halfConst, pLDotPL), @@ -790,6 +959,920 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { return success(); } + + LogicalResult lowerNUTS(enzyme::MCMCOp mcmcOp, + PatternRewriter &rewriter) const { + SymbolTableCollection symbolTable; + + auto fn = cast( + symbolTable.lookupNearestSymbolFrom(mcmcOp, mcmcOp.getFnAttr())); + + if (fn.getFunctionBody().empty()) { + mcmcOp.emitError( + "ProbProg: calling `mcmc` with NUTS on an empty function"); + return failure(); + } + + auto F64TensorType = RankedTensorType::get({}, rewriter.getF64Type()); + auto traceType = enzyme::TraceType::get(mcmcOp.getContext()); + + Value invMass = mcmcOp.getInverseMassMatrix(); + Value stepSize = mcmcOp.getStepSize(); + + if (!stepSize) { + mcmcOp.emitError("ProbProg: NUTS requires step_size parameter"); + return failure(); + } + + auto inputs = mcmcOp.getInputs(); + if (inputs.empty()) { + mcmcOp.emitError("ProbProg: initial RNG state is required as the first " + "function input by convention"); + return failure(); + } + + Value rngState = inputs[0]; + SmallVector fnInputs(inputs.begin() + 1, inputs.end()); + + auto loc = mcmcOp.getLoc(); + auto originalTrace = mcmcOp.getOriginalTrace(); + auto selection = mcmcOp.getSelectionAttr(); + + // 1. Extract initial position vector q0 + int64_t positionSize = + computePositionSizeForSelection(mcmcOp, fn, selection, symbolTable); + if (positionSize <= 0) + return failure(); + + auto positionType = + RankedTensorType::get({positionSize}, rewriter.getF64Type()); + + auto q0 = enzyme::GetFlattenedSamplesFromTraceOp::create( + rewriter, loc, positionType, originalTrace, selection); + + // 2. Compute initial potential energy U0 = -weight + auto weight0 = enzyme::GetWeightFromTraceOp::create( + rewriter, loc, F64TensorType, originalTrace); + Value U0 = conditionalDump(rewriter, loc, + arith::NegFOp::create(rewriter, loc, weight0), + "NUTS: initial potential energy U0"); + + auto zeroConst = arith::ConstantOp::create( + rewriter, loc, F64TensorType, + DenseElementsAttr::get(F64TensorType, rewriter.getF64FloatAttr(0.0))); + auto oneConst = arith::ConstantOp::create( + rewriter, loc, F64TensorType, + DenseElementsAttr::get(F64TensorType, rewriter.getF64FloatAttr(1.0))); + + Value rng1; + Value pInit; + + // 3. Sample initial momentum p0 ~ N(0, M) if M is provided, + // otherwise p0 ~ N(0, I) + Value initialMomentum = mcmcOp.getInitialMomentum(); + if (initialMomentum) { + pInit = initialMomentum; + rng1 = rngState; + } else { + if (invMass) { + auto invMassType = cast(invMass.getType()); + if (invMassType.getRank() == 1) { + // Diagonal case: p0 = 1 / sqrt(invMass) * eps where eps ~ N(0, I) + auto sqrtInvMass = math::SqrtOp::create(rewriter, loc, invMass); + auto onesVector = arith::ConstantOp::create( + rewriter, loc, invMassType, + DenseElementsAttr::get(invMassType, + rewriter.getF64FloatAttr(1.0))); + auto massMatrixSqrt = + arith::DivFOp::create(rewriter, loc, onesVector, sqrtInvMass); + auto randomOp = enzyme::RandomOp::create( + rewriter, loc, TypeRange{rngState.getType(), positionType}, + rngState, zeroConst, oneConst, + enzyme::RngDistributionAttr::get( + rewriter.getContext(), enzyme::RngDistribution::NORMAL)); + rng1 = randomOp.getOutputRngState(); + Value eps = randomOp.getResult(); + pInit = arith::MulFOp::create(rewriter, loc, massMatrixSqrt, eps); + } else { + // Dense case: p0 = mass_matrix_sqrt @ eps where eps ~ N(0, I) + auto identityMatrix = + createIdentityMatrix(rewriter, loc, invMassType); + auto massMatrixSqrt = enzyme::CholeskySolveOp::create( + rewriter, loc, invMassType, invMass, identityMatrix); + auto randomOp = enzyme::RandomOp::create( + rewriter, loc, TypeRange{rngState.getType(), positionType}, + rngState, zeroConst, oneConst, + enzyme::RngDistributionAttr::get( + rewriter.getContext(), enzyme::RngDistribution::NORMAL)); + rng1 = randomOp.getOutputRngState(); + Value eps = randomOp.getResult(); + pInit = enzyme::DotOp::create( + rewriter, loc, positionType, massMatrixSqrt, eps, + /*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*lhs_contracting_dimensions=*/ + rewriter.getDenseI64ArrayAttr({1}), + /*rhs_contracting_dimensions=*/ + rewriter.getDenseI64ArrayAttr({0})); + } + } else { + // Assume identity mass matrix: p0 ~ N(0, I) + auto randomOp = enzyme::RandomOp::create( + rewriter, loc, TypeRange{rngState.getType(), positionType}, + rngState, zeroConst, oneConst, + enzyme::RngDistributionAttr::get( + rewriter.getContext(), enzyme::RngDistribution::NORMAL)); + rng1 = randomOp.getOutputRngState(); + pInit = randomOp.getResult(); + } + } + + auto halfConst = arith::ConstantOp::create( + rewriter, loc, F64TensorType, + DenseElementsAttr::get(F64TensorType, rewriter.getF64FloatAttr(0.5))); + + // 4. Compute initial kinetic energy K0 = 0.5 * p^T * M^-1 * p + Value K0; + if (invMass) { + auto invMassType = cast(invMass.getType()); + Value invMassP0; + + if (invMassType.getRank() == 1) { + invMassP0 = arith::MulFOp::create(rewriter, loc, invMass, pInit); + } else { + invMassP0 = enzyme::DotOp::create( + rewriter, loc, positionType, invMass, pInit, + /*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({1}), + /*rhs_contracting_dimensions=*/ + rewriter.getDenseI64ArrayAttr({0})); + } + + auto p0DotInvMassP0 = enzyme::DotOp::create( + rewriter, loc, F64TensorType, pInit, invMassP0, + /*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0}), + /*rhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0})); + K0 = conditionalDump( + rewriter, loc, + arith::MulFOp::create(rewriter, loc, halfConst, p0DotInvMassP0), + "NUTS: initial kinetic energy K0"); + } else { + auto p0DotP0 = enzyme::DotOp::create( + rewriter, loc, F64TensorType, pInit, pInit, + /*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}), + /*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0}), + /*rhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0})); + K0 = conditionalDump( + rewriter, loc, + arith::MulFOp::create(rewriter, loc, halfConst, p0DotP0), + "NUTS: initial kinetic energy K0"); + } + + Value H0 = conditionalDump(rewriter, loc, + arith::AddFOp::create(rewriter, loc, U0, K0), + "NUTS: initial Hamiltonian H0"); + + // 5. Compute initial gradient at q0 + auto gradSeedInit = arith::ConstantOp::create( + rewriter, loc, F64TensorType, + DenseElementsAttr::get(F64TensorType, rewriter.getF64FloatAttr(1.0))); + auto autodiffInit = enzyme::AutoDiffRegionOp::create( + rewriter, loc, TypeRange{rng1.getType(), positionType}, + ValueRange{q0, gradSeedInit}, + rewriter.getArrayAttr({enzyme::ActivityAttr::get( + rewriter.getContext(), enzyme::Activity::enzyme_active)}), + rewriter.getArrayAttr( + {enzyme::ActivityAttr::get( + rewriter.getContext(), + enzyme::Activity::enzyme_activenoneed), // U0 not needed here + enzyme::ActivityAttr::get(rewriter.getContext(), + enzyme::Activity::enzyme_const)}), + rewriter.getI64IntegerAttr(1), rewriter.getBoolAttr(false), nullptr); + + Block *autodiffInitBlock = rewriter.createBlock(&autodiffInit.getBody()); + autodiffInitBlock->addArgument(positionType, loc); + + rewriter.setInsertionPointToStart(autodiffInitBlock); + Value q0Arg = autodiffInitBlock->getArgument(0); + + SmallVector updateInputsInit; + updateInputsInit.push_back(rng1); + updateInputsInit.append(fnInputs.begin(), fnInputs.end()); + + auto updateOpInit = enzyme::UpdateOp::create( + rewriter, loc, TypeRange{traceType, F64TensorType, rng1.getType()}, + mcmcOp.getFnAttr(), updateInputsInit, originalTrace, q0Arg, selection, + rewriter.getStringAttr("")); + Value w0 = updateOpInit.getWeight(); + Value rng0_out = updateOpInit.getOutputRngState(); + Value U0_init = arith::NegFOp::create(rewriter, loc, w0); + + enzyme::YieldOp::create(rewriter, loc, ValueRange{U0_init, rng0_out}); + + rewriter.setInsertionPointAfter(autodiffInit); + Value rng0_final = autodiffInit.getResult(0); + Value grad0 = autodiffInit.getResult(1); + + // 6. Set up NUTS doubling loop (outer) + auto i1TensorType = RankedTensorType::get({}, rewriter.getI1Type()); + auto i64TensorType = RankedTensorType::get({}, rewriter.getI64Type()); + + auto zeroI64 = arith::ConstantOp::create( + rewriter, loc, i64TensorType, + DenseElementsAttr::get(i64TensorType, rewriter.getI64IntegerAttr(0))); + auto oneI64 = arith::ConstantOp::create( + rewriter, loc, i64TensorType, + DenseElementsAttr::get(i64TensorType, rewriter.getI64IntegerAttr(1))); + auto falseConst = arith::ConstantOp::create( + rewriter, loc, i1TensorType, + DenseElementsAttr::get(i1TensorType, rewriter.getBoolAttr(false))); + auto trueConst = arith::ConstantOp::create( + rewriter, loc, i1TensorType, + DenseElementsAttr::get(i1TensorType, rewriter.getBoolAttr(true))); + auto zeroWeight = arith::ConstantOp::create( + rewriter, loc, F64TensorType, + DenseElementsAttr::get(F64TensorType, rewriter.getF64FloatAttr(0.0))); + + NUTSTree initialTree = {.q_left = q0, + .p_left = pInit, + .grad_left = grad0, + .q_right = q0, + .p_right = pInit, + .grad_right = grad0, + .q_proposal = q0, + .grad_proposal = grad0, + .U_proposal = U0, + .H_proposal = H0, + .depth = zeroI64, + .weight = zeroWeight, + .turning = falseConst, + .diverging = falseConst, + .sum_accept_probs = oneConst, + .num_proposals = oneI64, + .p_sum = pInit}; + + auto maxTreeDepth = arith::ConstantOp::create( + rewriter, loc, i64TensorType, + DenseElementsAttr::get( + i64TensorType, + rewriter.getI64IntegerAttr(10))); // TODO: Make adjustable + + auto maxDeltaEnergy = arith::ConstantOp::create( + rewriter, loc, F64TensorType, + DenseElementsAttr::get( + F64TensorType, + rewriter.getF64FloatAttr(1000.0))); // TODO: Make adjustable + + SmallVector whileLoopTypes = initialTree.getTypes(); + whileLoopTypes.push_back(rng0_final.getType()); + SmallVector whileLoopInitVals = initialTree.toValues(); + whileLoopInitVals.push_back(rng0_final); + + auto outerWhileOp = enzyme::WhileLoopOp::create( + rewriter, loc, whileLoopTypes, whileLoopInitVals); + + Block *outerCondBlock = + rewriter.createBlock(&outerWhileOp.getConditionRegion()); + for (auto type : whileLoopTypes) + outerCondBlock->addArgument(type, loc); + + rewriter.setInsertionPointToStart(outerCondBlock); + + SmallVector treeArgs(outerCondBlock->getArguments().begin(), + outerCondBlock->getArguments().begin() + + NUTSTree::NUM_FIELDS); + NUTSTree treeCond = NUTSTree::fromValues(treeArgs); + Value rngCond = outerCondBlock->getArgument(NUTSTree::NUM_FIELDS); + + // Condition 6a: depth < maxTreeDepth + auto notMaxDepth = + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + treeCond.depth, maxTreeDepth); + + // Condition 6b: NOT turning + auto notTurning = + arith::XOrIOp::create(rewriter, loc, treeCond.turning, trueConst); + + // Condition 6c: NOT diverging + auto notDiverging = + arith::XOrIOp::create(rewriter, loc, treeCond.diverging, trueConst); + + auto continueDoublingCond = arith::AndIOp::create( + rewriter, loc, + arith::AndIOp::create(rewriter, loc, notMaxDepth.getResult(), + notTurning.getResult()), + notDiverging.getResult()); + + enzyme::YieldOp::create(rewriter, loc, + ValueRange{continueDoublingCond.getResult()}); + + Block *outerBodyBlock = + rewriter.createBlock(&outerWhileOp.getBodyRegion()); + for (auto type : whileLoopTypes) + outerBodyBlock->addArgument(type, loc); + + rewriter.setInsertionPointToStart(outerBodyBlock); + + SmallVector treeBodyArgs(outerBodyBlock->getArguments().begin(), + outerBodyBlock->getArguments().begin() + + NUTSTree::NUM_FIELDS); + NUTSTree treeBody = NUTSTree::fromValues(treeBodyArgs); + Value rngBody = outerBodyBlock->getArgument(NUTSTree::NUM_FIELDS); + + // Body 6a: Sample direction (left or right) + auto rngSplitOp = enzyme::RandomSplitOp::create( + rewriter, loc, TypeRange{rngBody.getType(), rngBody.getType()}, + rngBody); + Value rngDir = rngSplitOp.getResult(0); + Value rngDoubling = rngSplitOp.getResult(1); + + auto randomDir = enzyme::RandomOp::create( + rewriter, loc, TypeRange{rngDir.getType(), F64TensorType}, rngDir, + zeroConst, oneConst, + enzyme::RngDistributionAttr::get(rewriter.getContext(), + enzyme::RngDistribution::UNIFORM)); + Value rngDir_out = randomDir.getOutputRngState(); + auto goingRight = + arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT, + randomDir.getResult(), halfConst); + + // Body 6b: Build subtree with 2^(currentDepth + 1) proposals. + Value currentDepth = treeBody.depth; + auto depthForSubtree = + arith::AddIOp::create(rewriter, loc, currentDepth, oneI64); + + // 7. Subtree building. + SmallVector innerWhileTypes = treeBody.getTypes(); + innerWhileTypes.push_back(rngDoubling.getType()); + + SmallVector innerWhileInitVals = treeBody.toValues(); + innerWhileInitVals.push_back(rngDoubling); + + auto innerWhileOp = enzyme::WhileLoopOp::create( + rewriter, loc, innerWhileTypes, innerWhileInitVals); + + Block *innerCondBlock = + rewriter.createBlock(&innerWhileOp.getConditionRegion()); + for (auto type : innerWhileTypes) + innerCondBlock->addArgument(type, loc); + + rewriter.setInsertionPointToStart(innerCondBlock); + + SmallVector subtreeCondArgs( + innerCondBlock->getArguments().begin(), + innerCondBlock->getArguments().begin() + NUTSTree::NUM_FIELDS); + NUTSTree subtreeCond = NUTSTree::fromValues(subtreeCondArgs); + Value rngInnerCond = innerCondBlock->getArgument(NUTSTree::NUM_FIELDS); + + // Condition 7a: depth < depthForSubtree + auto subtreeDepthOk = + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + subtreeCond.depth, depthForSubtree); + + // Condition 7b: NOT turning + auto subtreeNotTurning = + arith::XOrIOp::create(rewriter, loc, subtreeCond.turning, trueConst); + + // Condition 7c: NOT diverging + auto subtreeNotDiverging = arith::XOrIOp::create( + rewriter, loc, subtreeCond.diverging, trueConst); + + auto continueSubtreeCond = arith::AndIOp::create( + rewriter, loc, + arith::AndIOp::create(rewriter, loc, subtreeDepthOk.getResult(), + subtreeNotTurning.getResult()), + subtreeNotDiverging.getResult()); + + enzyme::YieldOp::create(rewriter, loc, + ValueRange{continueSubtreeCond.getResult()}); + + // Body 7a: Set up subtree building loop. + Block *innerBodyBlock = + rewriter.createBlock(&innerWhileOp.getBodyRegion()); + for (auto type : innerWhileTypes) + innerBodyBlock->addArgument(type, loc); + + rewriter.setInsertionPointToStart(innerBodyBlock); + + SmallVector subtreeIterArgs( + innerBodyBlock->getArguments().begin(), + innerBodyBlock->getArguments().begin() + NUTSTree::NUM_FIELDS); + NUTSTree subtreeIter = NUTSTree::fromValues(subtreeIterArgs); + Value rngIter = innerBodyBlock->getArgument(NUTSTree::NUM_FIELDS); + + // Body 7a: Extract boundary from subtree based on direction + auto goingRightBroadcast = enzyme::BroadcastOp::create( + rewriter, loc, + RankedTensorType::get(positionType.getShape(), rewriter.getI1Type()), + goingRight, rewriter.getDenseI64ArrayAttr(positionType.getShape())); + + Value leafQ = arith::SelectOp::create( + rewriter, loc, positionType, goingRightBroadcast, subtreeIter.q_right, + subtreeIter.q_left); + Value leafP = arith::SelectOp::create( + rewriter, loc, positionType, goingRightBroadcast, subtreeIter.p_right, + subtreeIter.p_left); + Value leafGrad = arith::SelectOp::create( + rewriter, loc, positionType, goingRightBroadcast, + subtreeIter.grad_right, subtreeIter.grad_left); + + // Body 7b: Prepare RNG states and adjust step size based on direction. + auto rngSplit3 = enzyme::RandomSplitOp::create( + rewriter, loc, + TypeRange{rngIter.getType(), rngIter.getType(), rngIter.getType()}, + rngIter); + Value rngLeaf = rngSplit3.getResult(0); + Value rngCombine = rngSplit3.getResult(1); + Value rngForNext = rngSplit3.getResult(2); + + auto negStepSize = arith::NegFOp::create(rewriter, loc, stepSize); + Value eps = arith::SelectOp::create(rewriter, loc, F64TensorType, + goingRight, stepSize, negStepSize); + + ArrayRef positionShape = positionType.getShape(); + auto stepSizeBroadcast = enzyme::BroadcastOp::create( + rewriter, loc, positionType, eps, + rewriter.getDenseI64ArrayAttr(positionShape)); + auto halfStep = arith::MulFOp::create(rewriter, loc, halfConst, eps); + auto halfStepBroadcast = enzyme::BroadcastOp::create( + rewriter, loc, positionType, halfStep, + rewriter.getDenseI64ArrayAttr(positionShape)); + + // Body 7c: Leapfrog integration. + + // Half step momentum: p_half = p - 0.5 * eps * gradient + auto deltaP1 = + arith::MulFOp::create(rewriter, loc, halfStepBroadcast, leafGrad); + Value pHalf = arith::SubFOp::create(rewriter, loc, leafP, deltaP1); + + // Full step position: q_new = q + eps * M^-1 * p_half + Value v; + if (invMass) { + auto invMassType = cast(invMass.getType()); + if (invMassType.getRank() == 1) { + v = arith::MulFOp::create(rewriter, loc, invMass, pHalf); + } else if (invMassType.getRank() == 2) { + v = enzyme::DotOp::create(rewriter, loc, positionType, invMass, pHalf, + rewriter.getDenseI64ArrayAttr({}), + rewriter.getDenseI64ArrayAttr({}), + rewriter.getDenseI64ArrayAttr({1}), + rewriter.getDenseI64ArrayAttr({0})); + } else { + mcmcOp.emitError("ProbProg: Unsupported rank for invMass"); + return failure(); + } + } else { + v = pHalf; + } + + auto deltaQ = arith::MulFOp::create(rewriter, loc, stepSizeBroadcast, v); + Value qNew = arith::AddFOp::create(rewriter, loc, leafQ, deltaQ); + + // Compute potential energy and gradient at new position `qNew`. + auto gradSeed = arith::ConstantOp::create( + rewriter, loc, F64TensorType, + DenseElementsAttr::get(F64TensorType, rewriter.getF64FloatAttr(1.0))); + // We do need the NLL (a.k.a. potential energy) here, so `enzyme_active` + // on the NLL. + auto autodiffOp = enzyme::AutoDiffRegionOp::create( + rewriter, loc, + TypeRange{F64TensorType, rngLeaf.getType(), positionType}, + ValueRange{qNew, gradSeed}, + rewriter.getArrayAttr({enzyme::ActivityAttr::get( + rewriter.getContext(), enzyme::Activity::enzyme_active)}), + rewriter.getArrayAttr( + {enzyme::ActivityAttr::get(rewriter.getContext(), + enzyme::Activity::enzyme_active), + enzyme::ActivityAttr::get(rewriter.getContext(), + enzyme::Activity::enzyme_const)}), + rewriter.getI64IntegerAttr(1), rewriter.getBoolAttr(false), nullptr); + + Block *autodiffBlock = rewriter.createBlock(&autodiffOp.getBody()); + autodiffBlock->addArgument(positionType, loc); + + rewriter.setInsertionPointToStart(autodiffBlock); + Value qNewArg = autodiffBlock->getArgument(0); + + SmallVector updateInputs; + updateInputs.push_back(rngLeaf); + updateInputs.append(fnInputs.begin(), fnInputs.end()); + + auto updateResult = enzyme::UpdateOp::create( + rewriter, loc, TypeRange{traceType, F64TensorType, rngLeaf.getType()}, + mcmcOp.getFnAttr(), updateInputs, originalTrace, qNewArg, selection, + rewriter.getStringAttr("")); + + Value todiff = + arith::NegFOp::create(rewriter, loc, updateResult.getWeight()); + + enzyme::YieldOp::create( + rewriter, loc, ValueRange{todiff, updateResult.getOutputRngState()}); + + rewriter.setInsertionPointAfter(autodiffOp); + + // AutodiffRegionOp returns: (UNew, RNG, dUNew/dqNew) + Value UNew = autodiffOp.getResult(0); + Value gradNew = autodiffOp.getResult(2); + + // Half step momentum: p_new = p_half - 0.5 * eps * grad_new + auto deltaP2 = + arith::MulFOp::create(rewriter, loc, halfStepBroadcast, gradNew); + Value pNew = arith::SubFOp::create(rewriter, loc, pHalf, deltaP2); + + // Body 7d: Compute kinetic energy. + Value invMassPNew; + if (invMass) { + auto invMassType = cast(invMass.getType()); + if (invMassType.getRank() == 1) { + invMassPNew = arith::MulFOp::create(rewriter, loc, invMass, pNew); + } else { + invMassPNew = + enzyme::DotOp::create(rewriter, loc, positionType, invMass, pNew, + rewriter.getDenseI64ArrayAttr({}), + rewriter.getDenseI64ArrayAttr({}), + rewriter.getDenseI64ArrayAttr({1}), + rewriter.getDenseI64ArrayAttr({0})); + } + } else { + invMassPNew = pNew; + } + + auto pDotInvMassP = enzyme::DotOp::create( + rewriter, loc, F64TensorType, pNew, invMassPNew, + rewriter.getDenseI64ArrayAttr({}), rewriter.getDenseI64ArrayAttr({}), + rewriter.getDenseI64ArrayAttr({0}), + rewriter.getDenseI64ArrayAttr({0})); + Value KNew = + arith::MulFOp::create(rewriter, loc, halfConst, pDotInvMassP); + Value ENew = arith::AddFOp::create(rewriter, loc, UNew, KNew); + + // Body 7e: Various checks. + auto deltaE = arith::SubFOp::create(rewriter, loc, ENew, H0); + + Value leafDiverging = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OGT, deltaE, maxDeltaEnergy); + auto treeWeight = arith::NegFOp::create(rewriter, loc, deltaE); + + // Body 7f: Compute acceptance probability. + auto acceptProbRaw = math::ExpOp::create(rewriter, loc, treeWeight); + auto acceptProb = + arith::MinimumFOp::create(rewriter, loc, acceptProbRaw, oneConst); + + // Body 7g: Create leaf tree state. + NUTSTree newLeaf = {.q_left = qNew, + .p_left = pNew, + .grad_left = gradNew, + .q_right = qNew, + .p_right = pNew, + .grad_right = gradNew, + .q_proposal = qNew, + .grad_proposal = gradNew, + .U_proposal = UNew, + .H_proposal = ENew, + .depth = zeroI64, + .weight = treeWeight, + .turning = falseConst, + .diverging = leafDiverging, + .sum_accept_probs = acceptProb, + .num_proposals = oneI64, + .p_sum = pNew}; + + // Body 7h: Combine new leaf with the current subtree. + // 7h.1: Update boundaries based on direction. + Value qLeft = arith::SelectOp::create(rewriter, loc, positionType, + goingRightBroadcast, + subtreeIter.q_left, newLeaf.q_left); + Value pLeft = arith::SelectOp::create(rewriter, loc, positionType, + goingRightBroadcast, + subtreeIter.p_left, newLeaf.p_left); + Value gradLeft = arith::SelectOp::create( + rewriter, loc, positionType, goingRightBroadcast, + subtreeIter.grad_left, newLeaf.grad_left); + + Value qRight = arith::SelectOp::create( + rewriter, loc, positionType, goingRightBroadcast, newLeaf.q_right, + subtreeIter.q_right); + Value pRight = arith::SelectOp::create( + rewriter, loc, positionType, goingRightBroadcast, newLeaf.p_right, + subtreeIter.p_right); + Value gradRight = arith::SelectOp::create( + rewriter, loc, positionType, goingRightBroadcast, newLeaf.grad_right, + subtreeIter.grad_right); + + // 7h.2: Combine weights using log_add_exp. + Value combinedWeight = enzyme::LogAddExpOp::create( + rewriter, loc, F64TensorType, subtreeIter.weight, newLeaf.weight); + + Value weightDiffCombine = arith::SubFOp::create( + rewriter, loc, newLeaf.weight, subtreeIter.weight); + Value acceptProbCombine = createSigmoid(rewriter, loc, weightDiffCombine); + + // 7h.3: Select proposal with multinomial sampling. + auto randomOpCombine = enzyme::RandomOp::create( + rewriter, loc, TypeRange{rngCombine.getType(), F64TensorType}, + rngCombine, zeroConst, oneConst, + enzyme::RngDistributionAttr::get(rewriter.getContext(), + enzyme::RngDistribution::UNIFORM)); + Value rngAfterCombine = randomOpCombine.getOutputRngState(); + Value uniformSampleCombine = randomOpCombine.getResult(); + + auto acceptNew = + arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OLT, + uniformSampleCombine, acceptProbCombine); + + auto acceptNewBroadcast = enzyme::BroadcastOp::create( + rewriter, loc, + RankedTensorType::get(positionType.getShape(), rewriter.getI1Type()), + acceptNew, rewriter.getDenseI64ArrayAttr(positionType.getShape())); + + Value qProposal = arith::SelectOp::create( + rewriter, loc, positionType, acceptNewBroadcast, newLeaf.q_proposal, + subtreeIter.q_proposal); + Value gradProposal = arith::SelectOp::create( + rewriter, loc, positionType, acceptNewBroadcast, + newLeaf.grad_proposal, subtreeIter.grad_proposal); + + Value UProposal = + arith::SelectOp::create(rewriter, loc, F64TensorType, acceptNew, + newLeaf.U_proposal, subtreeIter.U_proposal); + Value EProposal = + arith::SelectOp::create(rewriter, loc, F64TensorType, acceptNew, + newLeaf.H_proposal, subtreeIter.H_proposal); + + // 7h.4: Update metadata. + Value combinedDepth = + arith::AddIOp::create(rewriter, loc, subtreeIter.depth, oneI64); + Value combinedTurning = arith::OrIOp::create( + rewriter, loc, subtreeIter.turning, newLeaf.turning); + Value combinedDiverging = arith::OrIOp::create( + rewriter, loc, subtreeIter.diverging, newLeaf.diverging); + Value sumAcceptProbs = + arith::AddFOp::create(rewriter, loc, subtreeIter.sum_accept_probs, + newLeaf.sum_accept_probs); + Value numProposals = arith::AddIOp::create( + rewriter, loc, subtreeIter.num_proposals, newLeaf.num_proposals); + Value pSum = arith::AddFOp::create(rewriter, loc, subtreeIter.p_sum, + newLeaf.p_sum); + NUTSTree updatedSubtree = {.q_left = qLeft, + .p_left = pLeft, + .grad_left = gradLeft, + .q_right = qRight, + .p_right = pRight, + .grad_right = gradRight, + .q_proposal = qProposal, + .grad_proposal = gradProposal, + .U_proposal = UProposal, + .H_proposal = EProposal, + .depth = combinedDepth, + .weight = combinedWeight, + .turning = combinedTurning, + .diverging = combinedDiverging, + .sum_accept_probs = sumAcceptProbs, + .num_proposals = numProposals, + .p_sum = pSum}; + + // Body 7i: Check and update turning flag. + // Turning criterion: (p_left + p_right) · M^-1 · p_left >= 0 + // AND (p_left + p_right) · M^-1 · p_right >= 0 + Value invMassPLeft, invMassPRight; + if (invMass) { + auto invMassType = cast(invMass.getType()); + if (invMassType.getRank() == 1) { + invMassPLeft = arith::MulFOp::create(rewriter, loc, invMass, + updatedSubtree.p_left); + invMassPRight = arith::MulFOp::create(rewriter, loc, invMass, + updatedSubtree.p_right); + } else { + invMassPLeft = enzyme::DotOp::create( + rewriter, loc, positionType, invMass, updatedSubtree.p_left, + rewriter.getDenseI64ArrayAttr({}), + rewriter.getDenseI64ArrayAttr({}), + rewriter.getDenseI64ArrayAttr({1}), + rewriter.getDenseI64ArrayAttr({0})); + invMassPRight = enzyme::DotOp::create( + rewriter, loc, positionType, invMass, updatedSubtree.p_right, + rewriter.getDenseI64ArrayAttr({}), + rewriter.getDenseI64ArrayAttr({}), + rewriter.getDenseI64ArrayAttr({1}), + rewriter.getDenseI64ArrayAttr({0})); + } + } else { + invMassPLeft = updatedSubtree.p_left; + invMassPRight = updatedSubtree.p_right; + } + + auto dotLeft = enzyme::DotOp::create( + rewriter, loc, F64TensorType, updatedSubtree.p_sum, invMassPLeft, + rewriter.getDenseI64ArrayAttr({}), rewriter.getDenseI64ArrayAttr({}), + rewriter.getDenseI64ArrayAttr({0}), + rewriter.getDenseI64ArrayAttr({0})); + auto dotRight = enzyme::DotOp::create( + rewriter, loc, F64TensorType, updatedSubtree.p_sum, invMassPRight, + rewriter.getDenseI64ArrayAttr({}), rewriter.getDenseI64ArrayAttr({}), + rewriter.getDenseI64ArrayAttr({0}), + rewriter.getDenseI64ArrayAttr({0})); + auto leftNegative = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OLT, dotLeft, zeroConst); + auto rightNegative = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OLT, dotRight, zeroConst); + + auto turning = arith::OrIOp::create( + rewriter, loc, leftNegative.getResult(), rightNegative.getResult()); + + updatedSubtree.turning = turning.getResult(); + + SmallVector yieldVals = updatedSubtree.toValues(); + yieldVals.push_back(rngIter); + enzyme::YieldOp::create(rewriter, loc, yieldVals); + + // 8. Combine subtree with main tree. + rewriter.setInsertionPointAfter(innerWhileOp); + SmallVector subtreeValues(innerWhileOp.getResults().begin(), + innerWhileOp.getResults().begin() + + NUTSTree::NUM_FIELDS); + NUTSTree subtree = NUTSTree::fromValues(subtreeValues); + Value rngAfterBuild = innerWhileOp.getResult(NUTSTree::NUM_FIELDS); + + auto rngSplitAfterBuild = enzyme::RandomSplitOp::create( + rewriter, loc, + TypeRange{rngAfterBuild.getType(), rngAfterBuild.getType()}, + rngAfterBuild); + Value rngTrans = rngSplitAfterBuild.getResult(0); + Value rngNext = rngSplitAfterBuild.getResult(1); + + // 8a. Update boundaries based on direction. + auto goingRightMainBroadcast = enzyme::BroadcastOp::create( + rewriter, loc, + RankedTensorType::get(positionType.getShape(), rewriter.getI1Type()), + goingRight, rewriter.getDenseI64ArrayAttr(positionType.getShape())); + + Value qLeftMain = arith::SelectOp::create( + rewriter, loc, positionType, goingRightMainBroadcast, treeBody.q_left, + subtree.q_left); + Value pLeftMain = arith::SelectOp::create( + rewriter, loc, positionType, goingRightMainBroadcast, treeBody.p_left, + subtree.p_left); + Value gradLeftMain = arith::SelectOp::create( + rewriter, loc, positionType, goingRightMainBroadcast, + treeBody.grad_left, subtree.grad_left); + Value qRightMain = arith::SelectOp::create( + rewriter, loc, positionType, goingRightMainBroadcast, subtree.q_right, + treeBody.q_right); + Value pRightMain = arith::SelectOp::create( + rewriter, loc, positionType, goingRightMainBroadcast, subtree.p_right, + treeBody.p_right); + Value gradRightMain = arith::SelectOp::create( + rewriter, loc, positionType, goingRightMainBroadcast, + subtree.grad_right, treeBody.grad_right); + + // 8b. Combine weights using log_add_exp. + Value combinedWeightMain = enzyme::LogAddExpOp::create( + rewriter, loc, F64TensorType, treeBody.weight, subtree.weight); + + // 8c. Proposal selection via multinomial sampling. + Value weightDiffMain = + arith::SubFOp::create(rewriter, loc, subtree.weight, treeBody.weight); + Value acceptProbMainRaw = createSigmoid(rewriter, loc, weightDiffMain); + + // 8d. Zero accept probability to 0 if new tree is turning or diverging. + Value acceptProbMain = arith::SelectOp::create( + rewriter, loc, F64TensorType, + arith::OrIOp::create(rewriter, loc, subtree.turning, + subtree.diverging), + zeroConst, acceptProbMainRaw); + + // 8e. Compute acceptance probability on new proposal. + auto randomOpMain = enzyme::RandomOp::create( + rewriter, loc, TypeRange{rngTrans.getType(), F64TensorType}, rngTrans, + zeroConst, oneConst, + enzyme::RngDistributionAttr::get(rewriter.getContext(), + enzyme::RngDistribution::UNIFORM)); + Value rngAfterCombineFinal = randomOpMain.getOutputRngState(); + Value uniformSampleMain = randomOpMain.getResult(); + + auto acceptNewMain = + arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OLT, + uniformSampleMain, acceptProbMain); + + // 8f. Select proposal components. + auto acceptNewMainBroadcast = enzyme::BroadcastOp::create( + rewriter, loc, + RankedTensorType::get(positionType.getShape(), rewriter.getI1Type()), + acceptNewMain, + rewriter.getDenseI64ArrayAttr(positionType.getShape())); + + Value qProposalMain = arith::SelectOp::create( + rewriter, loc, positionType, acceptNewMainBroadcast, + subtree.q_proposal, treeBody.q_proposal); + Value gradProposalMain = arith::SelectOp::create( + rewriter, loc, positionType, acceptNewMainBroadcast, + subtree.grad_proposal, treeBody.grad_proposal); + + Value UProposalMain = + arith::SelectOp::create(rewriter, loc, F64TensorType, acceptNewMain, + subtree.U_proposal, treeBody.U_proposal); + Value EProposalMain = + arith::SelectOp::create(rewriter, loc, F64TensorType, acceptNewMain, + subtree.H_proposal, treeBody.H_proposal); + Value combinedDepthMain = + arith::AddIOp::create(rewriter, loc, treeBody.depth, oneI64); + Value combinedTurningMain = arith::OrIOp::create( + rewriter, loc, treeBody.turning, subtree.turning); + Value combinedDivergingMain = arith::OrIOp::create( + rewriter, loc, treeBody.diverging, subtree.diverging); + Value sumAcceptProbsMain = arith::AddFOp::create( + rewriter, loc, treeBody.sum_accept_probs, subtree.sum_accept_probs); + Value numProposalsMain = arith::AddIOp::create( + rewriter, loc, treeBody.num_proposals, subtree.num_proposals); + Value pSumMain = + arith::AddFOp::create(rewriter, loc, treeBody.p_sum, subtree.p_sum); + + NUTSTree combinedTree = {.q_left = qLeftMain, + .p_left = pLeftMain, + .grad_left = gradLeftMain, + .q_right = qRightMain, + .p_right = pRightMain, + .grad_right = gradRightMain, + .q_proposal = qProposalMain, + .grad_proposal = gradProposalMain, + .U_proposal = UProposalMain, + .H_proposal = EProposalMain, + .depth = combinedDepthMain, + .weight = combinedWeightMain, + .turning = combinedTurningMain, + .diverging = combinedDivergingMain, + .sum_accept_probs = sumAcceptProbsMain, + .num_proposals = numProposalsMain, + .p_sum = pSumMain}; + + // 8g. Yield combined tree. + SmallVector outerYieldVals = combinedTree.toValues(); + outerYieldVals.push_back(rngNext); + enzyme::YieldOp::create(rewriter, loc, outerYieldVals); + + rewriter.setInsertionPointAfter(outerWhileOp); + + // 9. Extract final proposal from combined tree. + SmallVector finalTreeValues(outerWhileOp.getResults().begin(), + outerWhileOp.getResults().begin() + + NUTSTree::NUM_FIELDS); + NUTSTree finalTree = NUTSTree::fromValues(finalTreeValues); + Value rngFinal = outerWhileOp.getResult(NUTSTree::NUM_FIELDS); + + Value qFinal = finalTree.q_proposal; + + // 10. Generate final trace at proposed position. + SmallVector finalUpdateInputs; + finalUpdateInputs.push_back(rngFinal); + finalUpdateInputs.append(fnInputs.begin(), fnInputs.end()); + + auto finalUpdateOp = enzyme::UpdateOp::create( + rewriter, loc, + TypeRange{traceType, F64TensorType, rngFinal.getType()}, + mcmcOp.getFnAttr(), finalUpdateInputs, originalTrace, qFinal, + selection, rewriter.getStringAttr("")); + Value finalTrace = finalUpdateOp.getUpdatedTrace(); + Value rngAfterUpdate = finalUpdateOp.getOutputRngState(); + + Value UFinal = conditionalDump( + rewriter, loc, + arith::NegFOp::create(rewriter, loc, finalUpdateOp.getWeight()), + "NUTS: final potential energy UFinal"); + + // 11. Metropolis-Hastings accept/reject step. + Value HFinal = conditionalDump(rewriter, loc, finalTree.H_proposal, + "NUTS: final Hamiltonian HFinal"); + + auto dH = arith::SubFOp::create(rewriter, loc, H0, HFinal); + auto accProbRaw = math::ExpOp::create(rewriter, loc, dH); + Value accProb = conditionalDump( + rewriter, loc, + arith::MinimumFOp::create(rewriter, loc, oneConst, accProbRaw), + "NUTS: acceptance probability α"); + + auto randomOp2 = enzyme::RandomOp::create( + rewriter, loc, TypeRange{rngAfterUpdate.getType(), F64TensorType}, + rngAfterUpdate, zeroConst, oneConst, + enzyme::RngDistributionAttr::get(rewriter.getContext(), + enzyme::RngDistribution::UNIFORM)); + Value rngFinalMH = randomOp2.getOutputRngState(); + Value randUniform = randomOp2.getResult(); + + auto acceptedTensor = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OLT, randUniform, accProb); + + // 12. Select trace based on acceptance + auto selectedTrace = enzyme::SelectTraceOp::create( + rewriter, loc, traceType, acceptedTensor, finalTrace, originalTrace); + + rewriter.replaceOp(mcmcOp, {selectedTrace, acceptedTensor, rngFinalMH}); + + return success(); + } }; struct LowerMHPattern : public mlir::OpRewritePattern { @@ -1728,10 +2811,10 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase { void ProbProgPass::runOnOperation() { RewritePatternSet patterns(&getContext()); patterns - .add( + .add( &getContext()); - patterns.add(&getContext(), debugMCMC); + patterns.add(&getContext(), debugDump); mlir::GreedyRewriteConfig config; diff --git a/enzyme/test/MLIR/ProbProg/hmc.mlir b/enzyme/test/MLIR/ProbProg/hmc.mlir index 2803554e507..9939d7d4531 100644 --- a/enzyme/test/MLIR/ProbProg/hmc.mlir +++ b/enzyme/test/MLIR/ProbProg/hmc.mlir @@ -14,12 +14,12 @@ module { %unused = enzyme.initTrace : !enzyme.Trace %init_trace = enzyme.initTrace : !enzyme.Trace - %mass = arith.constant dense<[[1.0, 0.0], [0.0, 1.0]]> : tensor<2x2xf64> + %inverse_mass_matrix = arith.constant dense<[[1.0, 0.0], [0.0, 1.0]]> : tensor<2x2xf64> %step_size = arith.constant dense<0.1> : tensor %num_steps = arith.constant dense<10> : tensor %res:3 = enzyme.mcmc algorithm = HMC @test(%rng, %mean, %stddev) given %init_trace - mass = %mass : tensor<2x2xf64> + inverse_mass_matrix = %inverse_mass_matrix : tensor<2x2xf64> step_size = %step_size : tensor num_steps = %num_steps : tensor { name = "hmc", selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]] } : (tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) @@ -41,49 +41,51 @@ module { // CHECK-NEXT: %1 = enzyme.getFlattenedSamplesFromTrace %0 {selection = {{\[}}[#enzyme.symbol<1>], [#enzyme.symbol<2>]{{\]}}} : tensor<2xf64> // CHECK-NEXT: %2 = enzyme.getWeightFromTrace %0 : tensor // CHECK-NEXT: %3 = arith.negf %2 : tensor -// CHECK-NEXT: %output_rng_state, %result = enzyme.random %arg0, %cst_4, %cst_7 {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor<2x2xf64>) -> (tensor<2xui64>, tensor<2xf64>) -// CHECK-NEXT: %4 = enzyme.cholesky_solve %cst_7, %result : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> -// CHECK-NEXT: %5 = enzyme.dot %result, %4 : (tensor<2xf64>, tensor<2xf64>) -> tensor -// CHECK-NEXT: %6 = arith.mulf %5, %cst_2 : tensor -// CHECK-NEXT: %7 = arith.addf %3, %6 : tensor -// CHECK-NEXT: %8:2 = enzyme.autodiff_region(%1, %cst_3) { +// CHECK-NEXT: %4 = enzyme.cholesky_solve %cst_7, %cst_7 : (tensor<2x2xf64>, tensor<2x2xf64>) -> tensor<2x2xf64> +// CHECK-NEXT: %output_rng_state, %result = enzyme.random %arg0, %cst_4, %cst_3 {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor<2xf64>) +// CHECK-NEXT: %5 = enzyme.dot %4, %result {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: %6 = enzyme.dot %cst_7, %5 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: %7 = enzyme.dot %5, %6 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2xf64>, tensor<2xf64>) -> tensor +// CHECK-NEXT: %8 = arith.mulf %7, %cst_2 : tensor +// CHECK-NEXT: %9 = arith.addf %3, %8 : tensor +// CHECK-NEXT: %10:2 = enzyme.autodiff_region(%1, %cst_3) { // CHECK-NEXT: ^bb0(%arg3: tensor<2xf64>): -// CHECK-NEXT: %23:3 = func.call @test.update_1(%0, %arg3, %output_rng_state, %arg1, %arg2) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) -// CHECK-NEXT: %24 = arith.negf %23#1 : tensor -// CHECK-NEXT: enzyme.yield %24, %23#2 : tensor, tensor<2xui64> +// CHECK-NEXT: %25:3 = func.call @test.update_1(%0, %arg3, %output_rng_state, %arg1, %arg2) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) +// CHECK-NEXT: %26 = arith.negf %25#1 : tensor +// CHECK-NEXT: enzyme.yield %26, %25#2 : tensor, tensor<2xui64> // CHECK-NEXT: } attributes {activity = [#enzyme], ret_activity = [#enzyme, #enzyme]} : (tensor<2xf64>, tensor) -> (tensor<2xui64>, tensor<2xf64>) -// CHECK-NEXT: %9 = "enzyme.broadcast"(%cst_6) <{shape = array}> : (tensor) -> tensor<2xf64> -// CHECK-NEXT: %10 = "enzyme.broadcast"(%cst) <{shape = array}> : (tensor) -> tensor<2xf64> -// CHECK-NEXT: %11:4 = enzyme.loop(%cst_1 : tensor) to(%cst_5 : tensor) step(%cst_0 : tensor) iter_args(%1, %result, %8#1, %8#0 : tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64>) -> tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64> { +// CHECK-NEXT: %11 = "enzyme.broadcast"(%cst_6) <{shape = array}> : (tensor) -> tensor<2xf64> +// CHECK-NEXT: %12 = "enzyme.broadcast"(%cst) <{shape = array}> : (tensor) -> tensor<2xf64> +// CHECK-NEXT: %13:4 = enzyme.for_loop(%cst_1 : tensor) to(%cst_5 : tensor) step(%cst_0 : tensor) iter_args(%1, %5, %10#1, %10#0 : tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64>) -> tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64> { // CHECK-NEXT: ^bb0(%arg3: tensor, %arg4: tensor<2xf64>, %arg5: tensor<2xf64>, %arg6: tensor<2xf64>, %arg7: tensor<2xui64>): -// CHECK-NEXT: %23 = arith.mulf %10, %arg6 : tensor<2xf64> -// CHECK-NEXT: %24 = arith.subf %arg5, %23 : tensor<2xf64> -// CHECK-NEXT: %25 = enzyme.cholesky_solve %cst_7, %24 : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> -// CHECK-NEXT: %26 = arith.mulf %9, %25 : tensor<2xf64> -// CHECK-NEXT: %27 = arith.addf %arg4, %26 : tensor<2xf64> -// CHECK-NEXT: %28:2 = enzyme.autodiff_region(%27, %cst_3) { +// CHECK-NEXT: %25 = arith.mulf %12, %arg6 : tensor<2xf64> +// CHECK-NEXT: %26 = arith.subf %arg5, %25 : tensor<2xf64> +// CHECK-NEXT: %27 = enzyme.dot %cst_7, %26 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: %28 = arith.mulf %11, %27 : tensor<2xf64> +// CHECK-NEXT: %29 = arith.addf %arg4, %28 : tensor<2xf64> +// CHECK-NEXT: %30:2 = enzyme.autodiff_region(%29, %cst_3) { // CHECK-NEXT: ^bb0(%arg8: tensor<2xf64>): -// CHECK-NEXT: %31:3 = func.call @test.update_0(%0, %arg8, %arg7, %arg1, %arg2) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) -// CHECK-NEXT: %32 = arith.negf %31#1 : tensor -// CHECK-NEXT: enzyme.yield %32, %31#2 : tensor, tensor<2xui64> +// CHECK-NEXT: %33:3 = func.call @test.update_0(%0, %arg8, %arg7, %arg1, %arg2) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) +// CHECK-NEXT: %34 = arith.negf %33#1 : tensor +// CHECK-NEXT: enzyme.yield %34, %33#2 : tensor, tensor<2xui64> // CHECK-NEXT: } attributes {activity = [#enzyme], ret_activity = [#enzyme, #enzyme]} : (tensor<2xf64>, tensor) -> (tensor<2xui64>, tensor<2xf64>) -// CHECK-NEXT: %29 = arith.mulf %10, %28#1 : tensor<2xf64> -// CHECK-NEXT: %30 = arith.subf %24, %29 : tensor<2xf64> -// CHECK-NEXT: enzyme.yield %27, %30, %28#1, %28#0 : tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64> +// CHECK-NEXT: %31 = arith.mulf %12, %30#1 : tensor<2xf64> +// CHECK-NEXT: %32 = arith.subf %26, %31 : tensor<2xf64> +// CHECK-NEXT: enzyme.yield %29, %32, %30#1, %30#0 : tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64> // CHECK-NEXT: } -// CHECK-NEXT: %12:3 = call @test.update(%0, %11#0, %11#3, %arg1, %arg2) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) -// CHECK-NEXT: %13 = arith.negf %12#1 : tensor -// CHECK-NEXT: %14 = enzyme.cholesky_solve %cst_7, %11#1 : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> -// CHECK-NEXT: %15 = enzyme.dot %11#1, %14 : (tensor<2xf64>, tensor<2xf64>) -> tensor -// CHECK-NEXT: %16 = arith.mulf %15, %cst_2 : tensor -// CHECK-NEXT: %17 = arith.addf %13, %16 : tensor -// CHECK-NEXT: %18 = arith.subf %7, %17 : tensor -// CHECK-NEXT: %19 = math.exp %18 : tensor -// CHECK-NEXT: %20 = arith.minimumf %19, %cst_3 : tensor -// CHECK-NEXT: %output_rng_state_8, %result_9 = enzyme.random %12#2, %cst_4, %cst_3 {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) -// CHECK-NEXT: %21 = arith.cmpf olt, %result_9, %20 : tensor -// CHECK-NEXT: %22 = enzyme.selectTrace %21, %12#0, %0 : tensor -// CHECK-NEXT: return %22, %21, %output_rng_state_8 : !enzyme.Trace, tensor, tensor<2xui64> +// CHECK-NEXT: %14:3 = call @test.update(%0, %13#0, %13#3, %arg1, %arg2) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) +// CHECK-NEXT: %15 = arith.negf %14#1 : tensor +// CHECK-NEXT: %16 = enzyme.dot %cst_7, %13#1 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: %17 = enzyme.dot %13#1, %16 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2xf64>, tensor<2xf64>) -> tensor +// CHECK-NEXT: %18 = arith.mulf %17, %cst_2 : tensor +// CHECK-NEXT: %19 = arith.addf %15, %18 : tensor +// CHECK-NEXT: %20 = arith.subf %9, %19 : tensor +// CHECK-NEXT: %21 = math.exp %20 : tensor +// CHECK-NEXT: %22 = arith.minimumf %21, %cst_3 : tensor +// CHECK-NEXT: %output_rng_state_8, %result_9 = enzyme.random %14#2, %cst_4, %cst_3 {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) +// CHECK-NEXT: %23 = arith.cmpf olt, %result_9, %22 : tensor +// CHECK-NEXT: %24 = enzyme.selectTrace %23, %14#0, %0 : tensor +// CHECK-NEXT: return %24, %23, %output_rng_state_8 : !enzyme.Trace, tensor, tensor<2xui64> // CHECK-NEXT: } // CHECK: func.func @test.update(%arg0: !enzyme.Trace, %arg1: tensor<2xf64>, %arg2: tensor<2xui64>, %arg3: tensor, %arg4: tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) { @@ -100,4 +102,4 @@ module { // CHECK-NEXT: %9 = enzyme.addWeightToTrace(%7 : tensor) into %8 // CHECK-NEXT: %10 = enzyme.addRetvalToTrace(%5 : tensor) into %9 // CHECK-NEXT: return %10, %7, %arg2 : !enzyme.Trace, tensor, tensor<2xui64> -// CHECK-NEXT: } +// CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/MLIR/ProbProg/nuts.mlir b/enzyme/test/MLIR/ProbProg/nuts.mlir new file mode 100644 index 00000000000..a136858804d --- /dev/null +++ b/enzyme/test/MLIR/ProbProg/nuts.mlir @@ -0,0 +1,184 @@ +// RUN: %eopt --probprog %s --mlir-print-ir-after=probprog | FileCheck %s + +module { + func.func private @normal(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor) + func.func private @logpdf(%x : tensor, %mean : tensor, %stddev : tensor) -> tensor + + func.func @test(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (tensor<2xui64>, tensor) { + %s:2 = enzyme.sample @normal(%rng, %mean, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<1>, name="s" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %t:2 = enzyme.sample @normal(%s#0, %s#1, %stddev) { logpdf = @logpdf, symbol = #enzyme.symbol<2>, name="t" } : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + return %t#0, %t#1 : tensor<2xui64>, tensor + } + + func.func @nuts(%rng : tensor<2xui64>, %mean : tensor, %stddev : tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) { + %init_trace = enzyme.initTrace : !enzyme.Trace + + %inverse_mass_matrix = arith.constant dense<[[1.0, 0.0], [0.0, 1.0]]> : tensor<2x2xf64> + %step_size = arith.constant dense<0.1> : tensor + + %res:3 = enzyme.mcmc algorithm = NUTS @test(%rng, %mean, %stddev) given %init_trace + inverse_mass_matrix = %inverse_mass_matrix : tensor<2x2xf64> + step_size = %step_size : tensor + { name = "nuts", selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]] } : (tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) + return %res#0, %res#1, %res#2 : !enzyme.Trace, tensor, tensor<2xui64> + } +} + +// CHECK: func.func @nuts(%arg0: tensor<2xui64>, %arg1: tensor, %arg2: tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) { +// CHECK-NEXT: %cst = arith.constant dense<-1.000000e-01> : tensor +// CHECK-NEXT: %cst_0 = arith.constant dense<1.000000e+03> : tensor +// CHECK-NEXT: %cst_1 = arith.constant dense<10> : tensor +// CHECK-NEXT: %cst_2 = arith.constant dense : tensor +// CHECK-NEXT: %cst_3 = arith.constant dense : tensor +// CHECK-NEXT: %cst_4 = arith.constant dense<1> : tensor +// CHECK-NEXT: %cst_5 = arith.constant dense<0> : tensor +// CHECK-NEXT: %cst_6 = arith.constant dense<5.000000e-01> : tensor +// CHECK-NEXT: %cst_7 = arith.constant dense<1.000000e+00> : tensor +// CHECK-NEXT: %cst_8 = arith.constant dense<0.000000e+00> : tensor +// CHECK-NEXT: %cst_9 = arith.constant dense<1.000000e-01> : tensor +// CHECK-NEXT: %cst_10 = arith.constant dense<{{\[}}[1.000000e+00, 0.000000e+00], [0.000000e+00, 1.000000e+00]{{\]}}> : tensor<2x2xf64> +// CHECK-NEXT: %0 = enzyme.initTrace : !enzyme.Trace +// CHECK-NEXT: %1 = enzyme.getFlattenedSamplesFromTrace %0 {selection = {{\[}}[#enzyme.symbol<1>], [#enzyme.symbol<2>]{{\]}}} : tensor<2xf64> +// CHECK-NEXT: %2 = enzyme.getWeightFromTrace %0 : tensor +// CHECK-NEXT: %3 = arith.negf %2 : tensor +// CHECK-NEXT: %4 = enzyme.cholesky_solve %cst_10, %cst_10 : (tensor<2x2xf64>, tensor<2x2xf64>) -> tensor<2x2xf64> +// CHECK-NEXT: %output_rng_state, %result = enzyme.random %arg0, %cst_8, %cst_7 {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor<2xf64>) +// CHECK-NEXT: %5 = enzyme.dot %4, %result {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: %6 = enzyme.dot %cst_10, %5 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: %7 = enzyme.dot %5, %6 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2xf64>, tensor<2xf64>) -> tensor +// CHECK-NEXT: %8 = arith.mulf %7, %cst_6 : tensor +// CHECK-NEXT: %9 = arith.addf %3, %8 : tensor +// CHECK-NEXT: %10:2 = enzyme.autodiff_region(%1, %cst_7) { +// CHECK-NEXT: ^bb0(%arg3: tensor<2xf64>): +// CHECK-NEXT: %18:3 = func.call @test.update_1(%0, %arg3, %output_rng_state, %arg1, %arg2) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) +// CHECK-NEXT: %19 = arith.negf %18#1 : tensor +// CHECK-NEXT: enzyme.yield %19, %18#2 : tensor, tensor<2xui64> +// CHECK-NEXT: } attributes {activity = [#enzyme], ret_activity = [#enzyme, #enzyme]} : (tensor<2xf64>, tensor) -> (tensor<2xui64>, tensor<2xf64>) +// CHECK-NEXT: %11:18 = enzyme.while_loop(%1, %5, %10#1, %1, %5, %10#1, %1, %10#1, %3, %9, %cst_5, %cst_8, %cst_3, %cst_3, %cst_7, %cst_4, %5, %10#0 : tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<2xf64>, tensor<2xui64>) -> tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<2xf64>, tensor<2xui64> condition { +// CHECK-NEXT: ^bb0(%arg3: tensor<2xf64>, %arg4: tensor<2xf64>, %arg5: tensor<2xf64>, %arg6: tensor<2xf64>, %arg7: tensor<2xf64>, %arg8: tensor<2xf64>, %arg9: tensor<2xf64>, %arg10: tensor<2xf64>, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor<2xf64>, %arg20: tensor<2xui64>): +// CHECK-NEXT: %18 = arith.cmpi slt, %arg13, %cst_1 : tensor +// CHECK-NEXT: %19 = arith.xori %arg15, %cst_2 : tensor +// CHECK-NEXT: %20 = arith.xori %arg16, %cst_2 : tensor +// CHECK-NEXT: %21 = arith.andi %18, %19 : tensor +// CHECK-NEXT: %22 = arith.andi %21, %20 : tensor +// CHECK-NEXT: enzyme.yield %22 : tensor +// CHECK-NEXT: } body { +// CHECK-NEXT: ^bb0(%arg3: tensor<2xf64>, %arg4: tensor<2xf64>, %arg5: tensor<2xf64>, %arg6: tensor<2xf64>, %arg7: tensor<2xf64>, %arg8: tensor<2xf64>, %arg9: tensor<2xf64>, %arg10: tensor<2xf64>, %arg11: tensor, %arg12: tensor, %arg13: tensor, %arg14: tensor, %arg15: tensor, %arg16: tensor, %arg17: tensor, %arg18: tensor, %arg19: tensor<2xf64>, %arg20: tensor<2xui64>): +// CHECK-NEXT: %18:2 = enzyme.randomSplit %arg20 : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2xui64>) +// CHECK-NEXT: %output_rng_state_13, %result_14 = enzyme.random %18#0, %cst_8, %cst_7 {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) +// CHECK-NEXT: %19 = arith.cmpf ogt, %result_14, %cst_6 : tensor +// CHECK-NEXT: %20 = arith.addi %arg13, %cst_4 : tensor +// CHECK-NEXT: %21:18 = enzyme.while_loop(%arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %18#1 : tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<2xf64>, tensor<2xui64>) -> tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<2xf64>, tensor<2xui64> condition { +// CHECK-NEXT: ^bb0(%arg21: tensor<2xf64>, %arg22: tensor<2xf64>, %arg23: tensor<2xf64>, %arg24: tensor<2xf64>, %arg25: tensor<2xf64>, %arg26: tensor<2xf64>, %arg27: tensor<2xf64>, %arg28: tensor<2xf64>, %arg29: tensor, %arg30: tensor, %arg31: tensor, %arg32: tensor, %arg33: tensor, %arg34: tensor, %arg35: tensor, %arg36: tensor, %arg37: tensor<2xf64>, %arg38: tensor<2xui64>): +// CHECK-NEXT: %50 = arith.cmpi slt, %arg31, %20 : tensor +// CHECK-NEXT: %51 = arith.xori %arg33, %cst_2 : tensor +// CHECK-NEXT: %52 = arith.xori %arg34, %cst_2 : tensor +// CHECK-NEXT: %53 = arith.andi %50, %51 : tensor +// CHECK-NEXT: %54 = arith.andi %53, %52 : tensor +// CHECK-NEXT: enzyme.yield %54 : tensor +// CHECK-NEXT: } body { +// CHECK-NEXT: ^bb0(%arg21: tensor<2xf64>, %arg22: tensor<2xf64>, %arg23: tensor<2xf64>, %arg24: tensor<2xf64>, %arg25: tensor<2xf64>, %arg26: tensor<2xf64>, %arg27: tensor<2xf64>, %arg28: tensor<2xf64>, %arg29: tensor, %arg30: tensor, %arg31: tensor, %arg32: tensor, %arg33: tensor, %arg34: tensor, %arg35: tensor, %arg36: tensor, %arg37: tensor<2xf64>, %arg38: tensor<2xui64>): +// CHECK-NEXT: %50 = "enzyme.broadcast"(%19) <{shape = array}> : (tensor) -> tensor<2xi1> +// CHECK-NEXT: %51 = arith.select %50, %arg24, %arg21 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %52 = arith.select %50, %arg25, %arg22 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %53 = arith.select %50, %arg26, %arg23 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %54:3 = enzyme.randomSplit %arg38 : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2xui64>, tensor<2xui64>) +// CHECK-NEXT: %55 = arith.select %19, %cst_9, %cst : tensor, tensor +// CHECK-NEXT: %56 = "enzyme.broadcast"(%55) <{shape = array}> : (tensor) -> tensor<2xf64> +// CHECK-NEXT: %57 = arith.mulf %55, %cst_6 : tensor +// CHECK-NEXT: %58 = "enzyme.broadcast"(%57) <{shape = array}> : (tensor) -> tensor<2xf64> +// CHECK-NEXT: %59 = arith.mulf %58, %53 : tensor<2xf64> +// CHECK-NEXT: %60 = arith.subf %52, %59 : tensor<2xf64> +// CHECK-NEXT: %61 = enzyme.dot %cst_10, %60 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: %62 = arith.mulf %56, %61 : tensor<2xf64> +// CHECK-NEXT: %63 = arith.addf %51, %62 : tensor<2xf64> +// CHECK-NEXT: %64:3 = enzyme.autodiff_region(%63, %cst_7) { +// CHECK-NEXT: ^bb0(%arg39: tensor<2xf64>): +// CHECK-NEXT: %106:3 = func.call @test.update_0(%0, %arg39, %54#0, %arg1, %arg2) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) +// CHECK-NEXT: %107 = arith.negf %106#1 : tensor +// CHECK-NEXT: enzyme.yield %107, %106#2 : tensor, tensor<2xui64> +// CHECK-NEXT: } attributes {activity = [#enzyme], ret_activity = [#enzyme, #enzyme]} : (tensor<2xf64>, tensor) -> (tensor, tensor<2xui64>, tensor<2xf64>) +// CHECK-NEXT: %65 = arith.mulf %58, %64#2 : tensor<2xf64> +// CHECK-NEXT: %66 = arith.subf %60, %65 : tensor<2xf64> +// CHECK-NEXT: %67 = enzyme.dot %cst_10, %66 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: %68 = enzyme.dot %66, %67 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2xf64>, tensor<2xf64>) -> tensor +// CHECK-NEXT: %69 = arith.mulf %68, %cst_6 : tensor +// CHECK-NEXT: %70 = arith.addf %64#0, %69 : tensor +// CHECK-NEXT: %71 = arith.subf %70, %9 : tensor +// CHECK-NEXT: %72 = arith.cmpf ogt, %71, %cst_0 : tensor +// CHECK-NEXT: %73 = arith.negf %71 : tensor +// CHECK-NEXT: %74 = math.exp %73 : tensor +// CHECK-NEXT: %75 = arith.minimumf %74, %cst_7 : tensor +// CHECK-NEXT: %76 = arith.select %50, %arg21, %63 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %77 = arith.select %50, %arg22, %66 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %78 = arith.select %50, %arg23, %64#2 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %79 = arith.select %50, %63, %arg24 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %80 = arith.select %50, %66, %arg25 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %81 = arith.select %50, %64#2, %arg26 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %82 = enzyme.log_add_exp %arg32, %73 : (tensor, tensor) -> tensor +// CHECK-NEXT: %83 = arith.subf %73, %arg32 : tensor +// CHECK-NEXT: %84 = arith.negf %83 : tensor +// CHECK-NEXT: %85 = math.exp %84 : tensor +// CHECK-NEXT: %86 = arith.addf %85, %cst_7 : tensor +// CHECK-NEXT: %87 = arith.divf %cst_7, %86 : tensor +// CHECK-NEXT: %output_rng_state_17, %result_18 = enzyme.random %54#1, %cst_8, %cst_7 {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) +// CHECK-NEXT: %88 = arith.cmpf olt, %result_18, %87 : tensor +// CHECK-NEXT: %89 = "enzyme.broadcast"(%88) <{shape = array}> : (tensor) -> tensor<2xi1> +// CHECK-NEXT: %90 = arith.select %89, %63, %arg27 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %91 = arith.select %89, %64#2, %arg28 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %92 = arith.select %88, %64#0, %arg29 : tensor, tensor +// CHECK-NEXT: %93 = arith.select %88, %70, %arg30 : tensor, tensor +// CHECK-NEXT: %94 = arith.addi %arg31, %cst_4 : tensor +// CHECK-NEXT: %95 = arith.ori %arg34, %72 : tensor +// CHECK-NEXT: %96 = arith.addf %arg35, %75 : tensor +// CHECK-NEXT: %97 = arith.addi %arg36, %cst_4 : tensor +// CHECK-NEXT: %98 = arith.addf %arg37, %66 : tensor<2xf64> +// CHECK-NEXT: %99 = enzyme.dot %cst_10, %77 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: %100 = enzyme.dot %cst_10, %80 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: %101 = enzyme.dot %98, %99 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2xf64>, tensor<2xf64>) -> tensor +// CHECK-NEXT: %102 = enzyme.dot %98, %100 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2xf64>, tensor<2xf64>) -> tensor +// CHECK-NEXT: %103 = arith.cmpf olt, %101, %cst_8 : tensor +// CHECK-NEXT: %104 = arith.cmpf olt, %102, %cst_8 : tensor +// CHECK-NEXT: %105 = arith.ori %103, %104 : tensor +// CHECK-NEXT: enzyme.yield %76, %77, %78, %79, %80, %81, %90, %91, %92, %93, %94, %82, %105, %95, %96, %97, %98, %arg38 : tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<2xf64>, tensor<2xui64> +// CHECK-NEXT: } +// CHECK-NEXT: %22:2 = enzyme.randomSplit %21#17 : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2xui64>) +// CHECK-NEXT: %23 = "enzyme.broadcast"(%19) <{shape = array}> : (tensor) -> tensor<2xi1> +// CHECK-NEXT: %24 = arith.select %23, %arg3, %21#0 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %25 = arith.select %23, %arg4, %21#1 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %26 = arith.select %23, %arg5, %21#2 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %27 = arith.select %23, %21#3, %arg6 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %28 = arith.select %23, %21#4, %arg7 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %29 = arith.select %23, %21#5, %arg8 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %30 = enzyme.log_add_exp %arg14, %21#11 : (tensor, tensor) -> tensor +// CHECK-NEXT: %31 = arith.subf %21#11, %arg14 : tensor +// CHECK-NEXT: %32 = arith.negf %31 : tensor +// CHECK-NEXT: %33 = math.exp %32 : tensor +// CHECK-NEXT: %34 = arith.addf %33, %cst_7 : tensor +// CHECK-NEXT: %35 = arith.divf %cst_7, %34 : tensor +// CHECK-NEXT: %36 = arith.ori %21#12, %21#13 : tensor +// CHECK-NEXT: %37 = arith.select %36, %cst_8, %35 : tensor, tensor +// CHECK-NEXT: %output_rng_state_15, %result_16 = enzyme.random %22#0, %cst_8, %cst_7 {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) +// CHECK-NEXT: %38 = arith.cmpf olt, %result_16, %37 : tensor +// CHECK-NEXT: %39 = "enzyme.broadcast"(%38) <{shape = array}> : (tensor) -> tensor<2xi1> +// CHECK-NEXT: %40 = arith.select %39, %21#6, %arg9 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %41 = arith.select %39, %21#7, %arg10 : tensor<2xi1>, tensor<2xf64> +// CHECK-NEXT: %42 = arith.select %38, %21#8, %arg11 : tensor, tensor +// CHECK-NEXT: %43 = arith.select %38, %21#9, %arg12 : tensor, tensor +// CHECK-NEXT: %44 = arith.addi %arg13, %cst_4 : tensor +// CHECK-NEXT: %45 = arith.ori %arg15, %21#12 : tensor +// CHECK-NEXT: %46 = arith.ori %arg16, %21#13 : tensor +// CHECK-NEXT: %47 = arith.addf %arg17, %21#14 : tensor +// CHECK-NEXT: %48 = arith.addi %arg18, %21#15 : tensor +// CHECK-NEXT: %49 = arith.addf %arg19, %21#16 : tensor<2xf64> +// CHECK-NEXT: enzyme.yield %24, %25, %26, %27, %28, %29, %40, %41, %42, %43, %44, %30, %45, %46, %47, %48, %49, %22#1 : tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor<2xf64>, tensor<2xui64> +// CHECK-NEXT: } +// CHECK-NEXT: %12:3 = call @test.update(%0, %11#6, %11#17, %arg1, %arg2) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) +// CHECK-NEXT: %13 = arith.subf %9, %11#9 : tensor +// CHECK-NEXT: %14 = math.exp %13 : tensor +// CHECK-NEXT: %15 = arith.minimumf %14, %cst_7 : tensor +// CHECK-NEXT: %output_rng_state_11, %result_12 = enzyme.random %12#2, %cst_8, %cst_7 {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) +// CHECK-NEXT: %16 = arith.cmpf olt, %result_12, %15 : tensor +// CHECK-NEXT: %17 = enzyme.selectTrace %16, %12#0, %0 : tensor +// CHECK-NEXT: return %17, %16, %output_rng_state_11 : !enzyme.Trace, tensor, tensor<2xui64> +// CHECK-NEXT: }