Skip to content

Commit 94ba9b7

Browse files
committed
take inverse_mass_matrix instead
1 parent 5f098c2 commit 94ba9b7

File tree

3 files changed

+183
-75
lines changed

3 files changed

+183
-75
lines changed

enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -877,7 +877,7 @@ def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods<SymbolUserOpInterface>
877877
and the 0th operand in results is the updated RNG state.
878878

879879
Optional HMC-specific parameters:
880-
- mass: Mass matrix (identity assumed if not provided)
880+
- inverse_mass_matrix: Inverse mass matrix (identity assumed if not provided).
881881
- step_size: Leapfrog integration step size
882882
- num_steps: Number of leapfrog steps
883883
- initial_momentum: deterministic initial momentum (debug)
@@ -889,7 +889,7 @@ def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods<SymbolUserOpInterface>
889889
Variadic<AnyType>:$inputs,
890890
Trace:$original_trace,
891891
AddressArrayAttr:$selection,
892-
Optional<AnyType>:$mass,
892+
Optional<AnyType>:$inverse_mass_matrix,
893893
Optional<AnyType>:$step_size,
894894
Optional<AnyType>:$num_steps,
895895
Optional<AnyType>:$initial_momentum,
@@ -900,7 +900,7 @@ def MCMCOp : Enzyme_Op<"mcmc", [DeclareOpInterfaceMethods<SymbolUserOpInterface>
900900

901901
let assemblyFormat = [{
902902
`algorithm` `=` $alg $fn `(` $inputs `)` `given` $original_trace
903-
(`mass` `=` $mass^ `:` type($mass))?
903+
(`inverse_mass_matrix` `=` $inverse_mass_matrix^ `:` type($inverse_mass_matrix))?
904904
(`step_size` `=` $step_size^ `:` type($step_size))?
905905
(`num_steps` `=` $num_steps^ `:` type($num_steps))?
906906
(`initial_momentum` `=` $initial_momentum^ `:` type($initial_momentum))?

enzyme/Enzyme/MLIR/Passes/ProbProgMLIRPass.cpp

Lines changed: 138 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,23 @@ namespace enzyme {
3939

4040
namespace {
4141

42+
static Value createIdentityMatrix(OpBuilder &builder, Location loc,
43+
RankedTensorType matrixType) {
44+
auto shape = matrixType.getShape();
45+
assert(shape.size() == 2 && shape[0] == shape[1] &&
46+
"Identity matrix must be square");
47+
int64_t n = shape[0];
48+
49+
SmallVector<double> identityData(n * n, 0.0);
50+
for (int64_t i = 0; i < n; ++i) {
51+
identityData[i * n + i] = 1.0;
52+
}
53+
54+
return builder.create<arith::ConstantOp>(
55+
loc, matrixType,
56+
DenseElementsAttr::get(matrixType, ArrayRef<double>(identityData)));
57+
}
58+
4259
static bool computePositionSizeForAddress(Operation *op,
4360
FunctionOpInterface func,
4461
ArrayRef<Attribute> address,
@@ -455,7 +472,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
455472
return failure();
456473
}
457474

458-
Value mass = mcmcOp.getMass();
475+
Value invMass = mcmcOp.getInverseMassMatrix();
459476
Value stepSize = mcmcOp.getStepSize();
460477
Value numSteps = mcmcOp.getNumSteps();
461478

@@ -508,18 +525,52 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
508525
p0 = initialMomentum;
509526
rng1 = rngState;
510527
} else {
511-
if (mass) {
512-
auto randomOp = enzyme::RandomOp::create(
513-
rewriter, loc, TypeRange{rngState.getType(), positionType},
514-
rngState, zeroConst, mass,
515-
enzyme::RngDistributionAttr::get(
516-
rewriter.getContext(), enzyme::RngDistribution::MULTINORMAL));
517-
rng1 = randomOp.getOutputRngState();
518-
p0 = randomOp.getResult();
528+
if (invMass) {
529+
auto invMassType = cast<RankedTensorType>(invMass.getType());
530+
if (invMassType.getRank() == 1) {
531+
// Diagonal case: p0 = 1 / sqrt(invMass) * eps where eps ~ N(0, I)
532+
auto sqrtInvMass = rewriter.create<math::SqrtOp>(loc, invMass);
533+
auto onesVector = rewriter.create<arith::ConstantOp>(
534+
loc, invMassType,
535+
DenseElementsAttr::get(invMassType,
536+
rewriter.getF64FloatAttr(1.0)));
537+
auto massMatrixSqrt =
538+
rewriter.create<arith::DivFOp>(loc, onesVector, sqrtInvMass);
539+
auto randomOp = rewriter.create<enzyme::RandomOp>(
540+
loc, TypeRange{rngState.getType(), positionType}, rngState,
541+
zeroConst, oneConst,
542+
enzyme::RngDistributionAttr::get(
543+
rewriter.getContext(), enzyme::RngDistribution::NORMAL));
544+
rng1 = randomOp.getOutputRngState();
545+
Value eps = randomOp.getResult();
546+
p0 = rewriter.create<arith::MulFOp>(loc, massMatrixSqrt, eps);
547+
} else {
548+
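// Dense case: p0 = mass_matrix_sqrt @ eps where eps ~ N(0, I). NOTE(review): CholeskySolve(invMass, I) appears to yield invMass^-1 (the full mass matrix M) rather than a square root of M, which would give p0 covariance M^2 instead of M — confirm a Cholesky factor of M is what is intended here.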
549+
auto identityMatrix =
550+
createIdentityMatrix(rewriter, loc, invMassType);
551+
auto massMatrixSqrt = rewriter.create<enzyme::CholeskySolveOp>(
552+
loc, invMassType, invMass, identityMatrix);
553+
auto randomOp = rewriter.create<enzyme::RandomOp>(
554+
loc, TypeRange{rngState.getType(), positionType}, rngState,
555+
zeroConst, oneConst,
556+
enzyme::RngDistributionAttr::get(
557+
rewriter.getContext(), enzyme::RngDistribution::NORMAL));
558+
rng1 = randomOp.getOutputRngState();
559+
Value eps = randomOp.getResult();
560+
p0 = rewriter.create<enzyme::DotOp>(
561+
loc, positionType, massMatrixSqrt, eps,
562+
/*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}),
563+
/*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}),
564+
/*lhs_contracting_dimensions=*/
565+
rewriter.getDenseI64ArrayAttr({1}),
566+
/*rhs_contracting_dimensions=*/
567+
rewriter.getDenseI64ArrayAttr({0}));
568+
}
519569
} else {
520-
auto randomOp = enzyme::RandomOp::create(
521-
rewriter, loc, TypeRange{rngState.getType(), positionType},
522-
rngState, zeroConst, oneConst,
570+
// Assume identity mass matrix: p0 ~ N(0, I)
571+
auto randomOp = rewriter.create<enzyme::RandomOp>(
572+
loc, TypeRange{rngState.getType(), positionType}, rngState,
573+
zeroConst, oneConst,
523574
enzyme::RngDistributionAttr::get(
524575
rewriter.getContext(), enzyme::RngDistribution::NORMAL));
525576
rng1 = randomOp.getOutputRngState();
@@ -531,19 +582,41 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
531582
rewriter, loc, tensorType,
532583
DenseElementsAttr::get(tensorType, rewriter.getF64FloatAttr(0.5)));
533584

534-
// 4. Compute initial kinetic energy K0 = 0.5 * p^T * M^{-1} * p
585+
// 4. Compute initial kinetic energy K0 = 0.5 * p^T * M^-1 * p
535586
Value K0;
536-
if (mass) {
537-
auto MInvP0 = enzyme::CholeskySolveOp::create(rewriter, loc,
538-
positionType, mass, p0);
539-
auto p0DotMInvP =
540-
enzyme::DotOp::create(rewriter, loc, tensorType, p0, MInvP0);
587+
if (invMass) {
588+
auto invMassType = cast<RankedTensorType>(invMass.getType());
589+
Value invMassP0;
590+
591+
if (invMassType.getRank() == 1) {
592+
invMassP0 = rewriter.create<arith::MulFOp>(loc, invMass, p0);
593+
} else {
594+
invMassP0 = rewriter.create<enzyme::DotOp>(
595+
loc, positionType, invMass, p0,
596+
/*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}),
597+
/*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}),
598+
/*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({1}),
599+
/*rhs_contracting_dimensions=*/
600+
rewriter.getDenseI64ArrayAttr({0}));
601+
}
602+
603+
auto p0DotInvMassP0 = rewriter.create<enzyme::DotOp>(
604+
loc, tensorType, p0, invMassP0,
605+
/*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}),
606+
/*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}),
607+
/*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0}),
608+
/*rhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0}));
541609
K0 = conditionalDump(
542610
rewriter, loc,
543-
arith::MulFOp::create(rewriter, loc, halfConst, p0DotMInvP),
611+
rewriter.create<arith::MulFOp>(loc, halfConst, p0DotInvMassP0),
544612
"HMC: initial kinetic energy K0");
545613
} else {
546-
auto p0DotP0 = enzyme::DotOp::create(rewriter, loc, tensorType, p0, p0);
614+
auto p0DotP0 = rewriter.create<enzyme::DotOp>(
615+
loc, tensorType, p0, p0,
616+
/*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}),
617+
/*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}),
618+
/*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0}),
619+
/*rhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0}));
547620
K0 = conditionalDump(
548621
rewriter, loc,
549622
arith::MulFOp::create(rewriter, loc, halfConst, p0DotP0),
@@ -647,11 +720,22 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
647720
rewriter, loc, arith::SubFOp::create(rewriter, loc, p, deltaP1),
648721
"Leapfrog: momentum p(t + eps/2)");
649722

650-
// 6.2 Full step on position: q += eps * M^{-1} * p1
723+
// 6.2 Full step on position: q += eps * M^-1 * p1
651724
Value v1;
652-
if (mass) {
653-
v1 = enzyme::CholeskySolveOp::create(rewriter, loc, positionType, mass,
654-
p1);
725+
if (invMass) {
726+
auto invMassType = cast<RankedTensorType>(invMass.getType());
727+
728+
if (invMassType.getRank() == 1) {
729+
v1 = rewriter.create<arith::MulFOp>(loc, invMass, p1);
730+
} else {
731+
v1 = rewriter.create<enzyme::DotOp>(
732+
loc, positionType, invMass, p1,
733+
/*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}),
734+
/*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}),
735+
/*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({1}),
736+
/*rhs_contracting_dimensions=*/
737+
rewriter.getDenseI64ArrayAttr({0}));
738+
}
655739
} else {
656740
v1 = p1;
657741
}
@@ -738,19 +822,41 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
738822
rewriter, loc, arith::NegFOp::create(rewriter, loc, weight1),
739823
"HMC: final potential energy U1");
740824

741-
// K1 = 0.5 * pL^T * M^{-1} * pL
825+
// K1 = 0.5 * pL^T * M^-1 * pL
742826
Value K1;
743-
if (mass) {
744-
auto MInvPL = enzyme::CholeskySolveOp::create(rewriter, loc,
745-
positionType, mass, pL);
746-
auto pLDotMInvPL =
747-
enzyme::DotOp::create(rewriter, loc, tensorType, pL, MInvPL);
827+
if (invMass) {
828+
auto invMassType = cast<RankedTensorType>(invMass.getType());
829+
Value invMassPL;
830+
831+
if (invMassType.getRank() == 1) {
832+
invMassPL = rewriter.create<arith::MulFOp>(loc, invMass, pL);
833+
} else {
834+
invMassPL = rewriter.create<enzyme::DotOp>(
835+
loc, positionType, invMass, pL,
836+
/*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}),
837+
/*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}),
838+
/*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({1}),
839+
/*rhs_contracting_dimensions=*/
840+
rewriter.getDenseI64ArrayAttr({0}));
841+
}
842+
843+
auto pLDotInvMassPL = rewriter.create<enzyme::DotOp>(
844+
loc, tensorType, pL, invMassPL,
845+
/*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}),
846+
/*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}),
847+
/*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0}),
848+
/*rhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0}));
748849
K1 = conditionalDump(
749850
rewriter, loc,
750-
arith::MulFOp::create(rewriter, loc, halfConst, pLDotMInvPL),
851+
rewriter.create<arith::MulFOp>(loc, halfConst, pLDotInvMassPL),
751852
"HMC: final kinetic energy K1");
752853
} else {
753-
auto pLDotPL = enzyme::DotOp::create(rewriter, loc, tensorType, pL, pL);
854+
auto pLDotPL = rewriter.create<enzyme::DotOp>(
855+
loc, tensorType, pL, pL,
856+
/*lhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}),
857+
/*rhs_batching_dimensions=*/rewriter.getDenseI64ArrayAttr({}),
858+
/*lhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0}),
859+
/*rhs_contracting_dimensions=*/rewriter.getDenseI64ArrayAttr({0}));
754860
K1 = conditionalDump(
755861
rewriter, loc,
756862
arith::MulFOp::create(rewriter, loc, halfConst, pLDotPL),

0 commit comments

Comments
 (0)