From 7fa17ce29cd1c987df2e6b77e5453cd6f44e2db1 Mon Sep 17 00:00:00 2001 From: sbrantq Date: Sat, 28 Jun 2025 21:42:46 -0500 Subject: [PATCH 1/2] active --> const --- enzyme/Enzyme/MLIR/Dialect/Ops.cpp | 34 ++++++++++++++++--- .../test/MLIR/ReverseMode/canonicalize.mlir | 8 +++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index 88de8e0c96c3..074375c25060 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -551,15 +551,35 @@ class ReverseRetOpt final : public OpRewritePattern { switch (val) { case Activity::enzyme_active: - if (!res.use_empty()) { - outs_args.push_back(res); - out_ty.push_back(res.getType()); - newRetActivityArgs.push_back(iattr); - } else { + if (res.use_empty()) { changed = true; auto new_activenn = ActivityAttr::get(rewriter.getContext(), Activity::enzyme_activenoneed); newRetActivityArgs.push_back(new_activenn); + } else { + int in_idx = 0; + for (auto act : inpActivity) { + auto v = cast(act).getValue(); + in_idx += + (v == Activity::enzyme_dup || v == Activity::enzyme_dupnoneed) + ? 2 + : 1; + } + in_idx += out_idx; + auto dres = uop.getInputs()[in_idx]; + + if (matchPattern(dres, m_Zero()) || + matchPattern(dres, m_AnyZeroFloat())) { + changed = true; + auto new_const = ActivityAttr::get(rewriter.getContext(), + Activity::enzyme_const); + newRetActivityArgs.push_back(new_const); + } else { + newRetActivityArgs.push_back(iattr); + } + + outs_args.push_back(res); + out_ty.push_back(res.getType()); } break; @@ -678,6 +698,10 @@ class ReverseRetOpt final : public OpRewritePattern { } else if (new_val == Activity::enzyme_constnoneed && old_val == Activity::enzyme_const) { ++oldIdx; // skip const primal + } else if (new_val == Activity::enzyme_const && + old_val == Activity::enzyme_active) { + uop.getOutputs()[oldIdx++].replaceAllUsesWith( + newOp.getOutputs()[newIdx++]); } } } diff --git a/enzyme/test/MLIR/ReverseMode/canonicalize.mlir b/enzyme/test/MLIR/ReverseMode/canonicalize.mlir index 48816d72b84d..853c56644cc0 100644 --- a/enzyme/test/MLIR/ReverseMode/canonicalize.mlir +++ b/enzyme/test/MLIR/ReverseMode/canonicalize.mlir @@ -41,4 +41,12 @@ module { // CHECK: enzyme.autodiff @square2(%arg0, %arg1, %arg2, %arg3){{.*}}activity = [#enzyme, #enzyme]{{.*}}ret_activity = [#enzyme, #enzyme]{{.*}} return %cst : f32 } + + // Test 5: active -> const for ret_activity (iff derivative is 0) + func.func @test5(%x: f32, %y: f32, %dr0: f32) -> (f32,f32,f32,f32) { + %cst = arith.constant 0.0000e+00 : f32 + %r:4 = enzyme.autodiff @square2(%x,%y,%dr0,%cst) { activity=[#enzyme, #enzyme], ret_activity=[#enzyme, #enzyme] } : (f32,f32,f32,f32) -> (f32,f32,f32,f32) + // CHECK: %{{.*}} = enzyme.autodiff @square2(%arg0, %arg1, %arg2, %cst){{.*}}activity = [#enzyme, #enzyme]{{.*}}ret_activity = [#enzyme, #enzyme]{{.*}} + return %r#0,%r#1,%r#2,%r#3 : f32,f32,f32,f32 + } } From a9945e68cddc22a5e8f97054e7d3b0c42800649d Mon Sep 17 00:00:00 2001 From: sbrantq Date: Sat, 28 Jun 2025 22:01:19 -0500 Subject: [PATCH 2/2] activenoneed --> constnoneed --- enzyme/Enzyme/MLIR/Dialect/Ops.cpp | 27 +++++++++++++++++-- .../test/MLIR/ReverseMode/canonicalize.mlir | 8 ++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index 074375c25060..169c34e59c42 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -541,7 +541,6 @@ class ReverseRetOpt final : public OpRewritePattern { // skip primal return if (val == Activity::enzyme_constnoneed || - val == Activity::enzyme_activenoneed || val == Activity::enzyme_dupnoneed) { newRetActivityArgs.push_back(iattr); continue; @@ -603,7 +602,31 @@ class ReverseRetOpt final : public OpRewritePattern { newRetActivityArgs.push_back(iattr); break; - case Activity::enzyme_activenoneed: + case Activity::enzyme_activenoneed: { + int in_idx = 0; + for (auto act : inpActivity) { + auto v = cast(act).getValue(); + in_idx += + (v == Activity::enzyme_dup || v == Activity::enzyme_dupnoneed) + ? 2 + : 1; + } + in_idx += out_idx; + + auto dres = uop.getInputs()[in_idx]; + + if (matchPattern(dres, m_Zero()) || + matchPattern(dres, m_AnyZeroFloat())) { + changed = true; + auto new_constnn = ActivityAttr::get(rewriter.getContext(), + Activity::enzyme_constnoneed); + newRetActivityArgs.push_back(new_constnn); + } else { + newRetActivityArgs.push_back(iattr); + } + + continue; + } case Activity::enzyme_constnoneed: case Activity::enzyme_dupnoneed: break; diff --git a/enzyme/test/MLIR/ReverseMode/canonicalize.mlir b/enzyme/test/MLIR/ReverseMode/canonicalize.mlir index 853c56644cc0..33676ec3d7ac 100644 --- a/enzyme/test/MLIR/ReverseMode/canonicalize.mlir +++ b/enzyme/test/MLIR/ReverseMode/canonicalize.mlir @@ -49,4 +49,12 @@ module { // CHECK: %{{.*}} = enzyme.autodiff @square2(%arg0, %arg1, %arg2, %cst){{.*}}activity = [#enzyme, #enzyme]{{.*}}ret_activity = [#enzyme, #enzyme]{{.*}} return %r#0,%r#1,%r#2,%r#3 : f32,f32,f32,f32 } + + // Test 6: active -> activenoneed/const -> constnoneed for ret_activity + func.func @test6(%x: f32, %y: f32, %dr0: f32) -> (f32,f32,f32) { + %cst = arith.constant 0.0000e+00 : f32 + %r:4 = enzyme.autodiff @square2(%x,%y,%dr0,%cst) { activity=[#enzyme, #enzyme], ret_activity=[#enzyme, #enzyme] } : (f32,f32,f32,f32) -> (f32,f32,f32,f32) + // CHECK: %{{.*}} = enzyme.autodiff @square2(%arg0, %arg1, %arg2, %cst){{.*}}activity = [#enzyme, #enzyme]{{.*}}ret_activity = [#enzyme, #enzyme]{{.*}} + return %r#0,%r#2,%r#3 : f32,f32,f32 + } }