@@ -39,6 +39,23 @@ namespace enzyme {
3939
4040namespace {
4141
42+ static Value createIdentityMatrix (OpBuilder &builder, Location loc,
43+ RankedTensorType matrixType) {
44+ auto shape = matrixType.getShape ();
45+ assert (shape.size () == 2 && shape[0 ] == shape[1 ] &&
46+ " Identity matrix must be square" );
47+ int64_t n = shape[0 ];
48+
49+ SmallVector<double > identityData (n * n, 0.0 );
50+ for (int64_t i = 0 ; i < n; ++i) {
51+ identityData[i * n + i] = 1.0 ;
52+ }
53+
54+ return builder.create <arith::ConstantOp>(
55+ loc, matrixType,
56+ DenseElementsAttr::get (matrixType, ArrayRef<double >(identityData)));
57+ }
58+
4259static bool computePositionSizeForAddress (Operation *op,
4360 FunctionOpInterface func,
4461 ArrayRef<Attribute> address,
@@ -455,7 +472,7 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
455472 return failure ();
456473 }
457474
458- Value mass = mcmcOp.getMass ();
475+ Value invMass = mcmcOp.getInverseMassMatrix ();
459476 Value stepSize = mcmcOp.getStepSize ();
460477 Value numSteps = mcmcOp.getNumSteps ();
461478
@@ -508,18 +525,52 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
508525 p0 = initialMomentum;
509526 rng1 = rngState;
510527 } else {
511- if (mass) {
512- auto randomOp = enzyme::RandomOp::create (
513- rewriter, loc, TypeRange{rngState.getType (), positionType},
514- rngState, zeroConst, mass,
515- enzyme::RngDistributionAttr::get (
516- rewriter.getContext (), enzyme::RngDistribution::MULTINORMAL));
517- rng1 = randomOp.getOutputRngState ();
518- p0 = randomOp.getResult ();
528+ if (invMass) {
529+ auto invMassType = cast<RankedTensorType>(invMass.getType ());
530+ if (invMassType.getRank () == 1 ) {
531+ // Diagonal case: p0 = 1 / sqrt(invMass) * eps where eps ~ N(0, I)
532+ auto sqrtInvMass = rewriter.create <math::SqrtOp>(loc, invMass);
533+ auto onesVector = rewriter.create <arith::ConstantOp>(
534+ loc, invMassType,
535+ DenseElementsAttr::get (invMassType,
536+ rewriter.getF64FloatAttr (1.0 )));
537+ auto massMatrixSqrt =
538+ rewriter.create <arith::DivFOp>(loc, onesVector, sqrtInvMass);
539+ auto randomOp = rewriter.create <enzyme::RandomOp>(
540+ loc, TypeRange{rngState.getType (), positionType}, rngState,
541+ zeroConst, oneConst,
542+ enzyme::RngDistributionAttr::get (
543+ rewriter.getContext (), enzyme::RngDistribution::NORMAL));
544+ rng1 = randomOp.getOutputRngState ();
545+ Value eps = randomOp.getResult ();
546+ p0 = rewriter.create <arith::MulFOp>(loc, massMatrixSqrt, eps);
547+ } else {
548+ // Dense case: p0 = mass_matrix_sqrt @ eps where eps ~ N(0, I)
549+ auto identityMatrix =
550+ createIdentityMatrix (rewriter, loc, invMassType);
551+ auto massMatrixSqrt = rewriter.create <enzyme::CholeskySolveOp>(
552+ loc, invMassType, invMass, identityMatrix);
553+ auto randomOp = rewriter.create <enzyme::RandomOp>(
554+ loc, TypeRange{rngState.getType (), positionType}, rngState,
555+ zeroConst, oneConst,
556+ enzyme::RngDistributionAttr::get (
557+ rewriter.getContext (), enzyme::RngDistribution::NORMAL));
558+ rng1 = randomOp.getOutputRngState ();
559+ Value eps = randomOp.getResult ();
560+ p0 = rewriter.create <enzyme::DotOp>(
561+ loc, positionType, massMatrixSqrt, eps,
562+ /* lhs_batching_dimensions=*/ rewriter.getDenseI64ArrayAttr ({}),
563+ /* rhs_batching_dimensions=*/ rewriter.getDenseI64ArrayAttr ({}),
564+ /* lhs_contracting_dimensions=*/
565+ rewriter.getDenseI64ArrayAttr ({1 }),
566+ /* rhs_contracting_dimensions=*/
567+ rewriter.getDenseI64ArrayAttr ({0 }));
568+ }
519569 } else {
520- auto randomOp = enzyme::RandomOp::create (
521- rewriter, loc, TypeRange{rngState.getType (), positionType},
522- rngState, zeroConst, oneConst,
570+ // Assume identity mass matrix: p0 ~ N(0, I)
571+ auto randomOp = rewriter.create <enzyme::RandomOp>(
572+ loc, TypeRange{rngState.getType (), positionType}, rngState,
573+ zeroConst, oneConst,
523574 enzyme::RngDistributionAttr::get (
524575 rewriter.getContext (), enzyme::RngDistribution::NORMAL));
525576 rng1 = randomOp.getOutputRngState ();
@@ -531,19 +582,41 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
531582 rewriter, loc, tensorType,
532583 DenseElementsAttr::get (tensorType, rewriter.getF64FloatAttr (0.5 )));
533584
534- // 4. Compute initial kinetic energy K0 = 0.5 * p^T * M^{-1} * p
585+ // 4. Compute initial kinetic energy K0 = 0.5 * p^T * M^-1 * p
535586 Value K0;
536- if (mass) {
537- auto MInvP0 = enzyme::CholeskySolveOp::create (rewriter, loc,
538- positionType, mass, p0);
539- auto p0DotMInvP =
540- enzyme::DotOp::create (rewriter, loc, tensorType, p0, MInvP0);
587+ if (invMass) {
588+ auto invMassType = cast<RankedTensorType>(invMass.getType ());
589+ Value invMassP0;
590+
591+ if (invMassType.getRank () == 1 ) {
592+ invMassP0 = rewriter.create <arith::MulFOp>(loc, invMass, p0);
593+ } else {
594+ invMassP0 = rewriter.create <enzyme::DotOp>(
595+ loc, positionType, invMass, p0,
596+ /* lhs_batching_dimensions=*/ rewriter.getDenseI64ArrayAttr ({}),
597+ /* rhs_batching_dimensions=*/ rewriter.getDenseI64ArrayAttr ({}),
598+ /* lhs_contracting_dimensions=*/ rewriter.getDenseI64ArrayAttr ({1 }),
599+ /* rhs_contracting_dimensions=*/
600+ rewriter.getDenseI64ArrayAttr ({0 }));
601+ }
602+
603+ auto p0DotInvMassP0 = rewriter.create <enzyme::DotOp>(
604+ loc, tensorType, p0, invMassP0,
605+ /* lhs_batching_dimensions=*/ rewriter.getDenseI64ArrayAttr ({}),
606+ /* rhs_batching_dimensions=*/ rewriter.getDenseI64ArrayAttr ({}),
607+ /* lhs_contracting_dimensions=*/ rewriter.getDenseI64ArrayAttr ({0 }),
608+ /* rhs_contracting_dimensions=*/ rewriter.getDenseI64ArrayAttr ({0 }));
541609 K0 = conditionalDump (
542610 rewriter, loc,
543- arith::MulFOp::create (rewriter, loc, halfConst, p0DotMInvP ),
611+ rewriter. create < arith::MulFOp>( loc, halfConst, p0DotInvMassP0 ),
544612 " HMC: initial kinetic energy K0" );
545613 } else {
546- auto p0DotP0 = enzyme::DotOp::create (rewriter, loc, tensorType, p0, p0);
614+ auto p0DotP0 = rewriter.create <enzyme::DotOp>(
615+ loc, tensorType, p0, p0,
616+ /* lhs_batching_dimensions=*/ rewriter.getDenseI64ArrayAttr ({}),
617+ /* rhs_batching_dimensions=*/ rewriter.getDenseI64ArrayAttr ({}),
618+ /* lhs_contracting_dimensions=*/ rewriter.getDenseI64ArrayAttr ({0 }),
619+ /* rhs_contracting_dimensions=*/ rewriter.getDenseI64ArrayAttr ({0 }));
547620 K0 = conditionalDump (
548621 rewriter, loc,
549622 arith::MulFOp::create (rewriter, loc, halfConst, p0DotP0),
@@ -647,11 +720,22 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
647720 rewriter, loc, arith::SubFOp::create (rewriter, loc, p, deltaP1),
648721 " Leapfrog: momentum p(t + eps/2)" );
649722
650- // 6.2 Full step on position: q += eps * M^{-1} * p1
723+ // 6.2 Full step on position: q += eps * M^-1 * p1
651724 Value v1;
652- if (mass) {
653- v1 = enzyme::CholeskySolveOp::create (rewriter, loc, positionType, mass,
654- p1);
725+ if (invMass) {
726+ auto invMassType = cast<RankedTensorType>(invMass.getType ());
727+
728+ if (invMassType.getRank () == 1 ) {
729+ v1 = rewriter.create <arith::MulFOp>(loc, invMass, p1);
730+ } else {
731+ v1 = rewriter.create <enzyme::DotOp>(
732+ loc, positionType, invMass, p1,
733+ /* lhs_batching_dimensions=*/ rewriter.getDenseI64ArrayAttr ({}),
734+ /* rhs_batching_dimensions=*/ rewriter.getDenseI64ArrayAttr ({}),
735+ /* lhs_contracting_dimensions=*/ rewriter.getDenseI64ArrayAttr ({1 }),
736+ /* rhs_contracting_dimensions=*/
737+ rewriter.getDenseI64ArrayAttr ({0 }));
738+ }
655739 } else {
656740 v1 = p1;
657741 }
@@ -738,19 +822,41 @@ struct ProbProgPass : public enzyme::impl::ProbProgPassBase<ProbProgPass> {
738822 rewriter, loc, arith::NegFOp::create (rewriter, loc, weight1),
739823 " HMC: final potential energy U1" );
740824
741- // K1 = 0.5 * pL^T * M^{-1} * pL
825+ // K1 = 0.5 * pL^T * M^-1 * pL
742826 Value K1;
743- if (mass) {
744- auto MInvPL = enzyme::CholeskySolveOp::create (rewriter, loc,
745- positionType, mass, pL);
746- auto pLDotMInvPL =
747- enzyme::DotOp::create (rewriter, loc, tensorType, pL, MInvPL);
827+ if (invMass) {
828+ auto invMassType = cast<RankedTensorType>(invMass.getType ());
829+ Value invMassPL;
830+
831+ if (invMassType.getRank () == 1 ) {
832+ invMassPL = rewriter.create <arith::MulFOp>(loc, invMass, pL);
833+ } else {
834+ invMassPL = rewriter.create <enzyme::DotOp>(
835+ loc, positionType, invMass, pL,
836+ /* lhs_batching_dimensions=*/ rewriter.getDenseI64ArrayAttr ({}),
837+ /* rhs_batching_dimensions=*/ rewriter.getDenseI64ArrayAttr ({}),
838+ /* lhs_contracting_dimensions=*/ rewriter.getDenseI64ArrayAttr ({1 }),
839+ /* rhs_contracting_dimensions=*/
840+ rewriter.getDenseI64ArrayAttr ({0 }));
841+ }
842+
843+ auto pLDotInvMassPL = rewriter.create <enzyme::DotOp>(
844+ loc, tensorType, pL, invMassPL,
845+ /* lhs_batching_dimensions=*/ rewriter.getDenseI64ArrayAttr ({}),
846+ /* rhs_batching_dimensions=*/ rewriter.getDenseI64ArrayAttr ({}),
847+ /* lhs_contracting_dimensions=*/ rewriter.getDenseI64ArrayAttr ({0 }),
848+ /* rhs_contracting_dimensions=*/ rewriter.getDenseI64ArrayAttr ({0 }));
748849 K1 = conditionalDump (
749850 rewriter, loc,
750- arith::MulFOp::create (rewriter, loc, halfConst, pLDotMInvPL ),
851+ rewriter. create < arith::MulFOp>( loc, halfConst, pLDotInvMassPL ),
751852 " HMC: final kinetic energy K1" );
752853 } else {
753- auto pLDotPL = enzyme::DotOp::create (rewriter, loc, tensorType, pL, pL);
854+ auto pLDotPL = rewriter.create <enzyme::DotOp>(
855+ loc, tensorType, pL, pL,
856+ /* lhs_batching_dimensions=*/ rewriter.getDenseI64ArrayAttr ({}),
857+ /* rhs_batching_dimensions=*/ rewriter.getDenseI64ArrayAttr ({}),
858+ /* lhs_contracting_dimensions=*/ rewriter.getDenseI64ArrayAttr ({0 }),
859+ /* rhs_contracting_dimensions=*/ rewriter.getDenseI64ArrayAttr ({0 }));
754860 K1 = conditionalDump (
755861 rewriter, loc,
756862 arith::MulFOp::create (rewriter, loc, halfConst, pLDotPL),
0 commit comments