From da2ca8cc9f8a305a0bcc3d3ac61175402138babb Mon Sep 17 00:00:00 2001
From: Andrew Lenharth <andrew@lenharth.org>
Date: Tue, 21 May 2024 22:57:18 -0500
Subject: [PATCH] [FIRRTL] Register reset elimination based on invalid can look
 through nodes. (#7069)

This converts wires into nodes when there is one write to the wire and it dominates the reads.  By converting to nodes, this pass does not have to worry about symbols, references, or annotations.  Those are just copied to the node.
---
 lib/Dialect/FIRRTL/Transforms/InferResets.cpp | 2 +-
 lib/Dialect/FIRRTL/Transforms/SFCCompat.cpp   | 2 +-
 test/Dialect/FIRRTL/SFCTests/data-taps.fir    | 2 ++
 test/Dialect/FIRRTL/SFCTests/dedup.fir        | 2 +-
 test/Dialect/FIRRTL/sfc-compat.mlir           | 5 +++--
 5 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/lib/Dialect/FIRRTL/Transforms/InferResets.cpp b/lib/Dialect/FIRRTL/Transforms/InferResets.cpp
index 951813353dfb..7c4299183abc 100644
--- a/lib/Dialect/FIRRTL/Transforms/InferResets.cpp
+++ b/lib/Dialect/FIRRTL/Transforms/InferResets.cpp
@@ -665,7 +665,7 @@ static bool getDeclName(Value value, SmallString<32> &string) {
             op.getPortName(cast<OpResult>(value).getResultNumber()).getValue();
         return true;
       })
-      .Case<WireOp, RegOp, RegResetOp>([&](auto op) {
+      .Case<WireOp, NodeOp, RegOp, RegResetOp>([&](auto op) {
         string += op.getName();
         return true;
       })
diff --git a/lib/Dialect/FIRRTL/Transforms/SFCCompat.cpp b/lib/Dialect/FIRRTL/Transforms/SFCCompat.cpp
index e8bc2b20ab2f..80fa52b31cb9 100644
--- a/lib/Dialect/FIRRTL/Transforms/SFCCompat.cpp
+++ b/lib/Dialect/FIRRTL/Transforms/SFCCompat.cpp
@@ -75,7 +75,7 @@ void SFCCompatPass::runOnOperation() {
     // If the `RegResetOp` has an invalidated initialization and we
     // are not running FART, then replace it with a `RegOp`.
     if (!fullAsyncResetExists &&
-        walkDrivers(reg.getResetValue(), true, false, false,
+        walkDrivers(reg.getResetValue(), true, true, false,
                     [](FieldRef dst, FieldRef src) {
                       return src.isa<InvalidValueOp>();
                     })) {
diff --git a/test/Dialect/FIRRTL/SFCTests/data-taps.fir b/test/Dialect/FIRRTL/SFCTests/data-taps.fir
index 4c1171a3aba7..535a9b73be4f 100644
--- a/test/Dialect/FIRRTL/SFCTests/data-taps.fir
+++ b/test/Dialect/FIRRTL/SFCTests/data-taps.fir
@@ -200,6 +200,8 @@ circuit Top : %[[
     io.d <= d
 
     ; CHECK:      module Top
+    ; TODO: fix having constants carry names
+    ; CHECK:        wire inv = 1'h0
     ; CHECK:        io_b = Top.foo.f_probe;
     ; CHECK-NEXT:   io_c = Top.foo.g_probe;
     ; CHECK-NEXT:   io_d = inv;
diff --git a/test/Dialect/FIRRTL/SFCTests/dedup.fir b/test/Dialect/FIRRTL/SFCTests/dedup.fir
index cc0368676c58..e3f6bd2dcc22 100644
--- a/test/Dialect/FIRRTL/SFCTests/dedup.fir
+++ b/test/Dialect/FIRRTL/SFCTests/dedup.fir
@@ -905,7 +905,7 @@ circuit Top18 : %[[
   ; CHECK: module private @A(
   module A :
     output x: UInt<1>
-    ; CHECK-NEXT: firrtl.wire
+    ; CHECK: firrtl.wire
     ; CHECK-SAME: [{class = "firrtl.transforms.DontTouchAnnotation"}]
     wire b: UInt<1>
     b is invalid
diff --git a/test/Dialect/FIRRTL/sfc-compat.mlir b/test/Dialect/FIRRTL/sfc-compat.mlir
index 5613a140d963..92e0fefabcee 100644
--- a/test/Dialect/FIRRTL/sfc-compat.mlir
+++ b/test/Dialect/FIRRTL/sfc-compat.mlir
@@ -93,13 +93,14 @@ firrtl.circuit "SFCCompatTests" {
     firrtl.connect %q, %r : !firrtl.uint<1>, !firrtl.uint<1>
   }
 
-  // A regreset invalid value should NOT propagate through a node.
+  // A regreset invalid value should propagate through a node.
+  // Change from SFC behavior.
   firrtl.module @InvalidNode(in %clock: !firrtl.clock, in %reset: !firrtl.uint<1>, in %d: !firrtl.uint<8>, out %q: !firrtl.uint<8>) {
     %inv = firrtl.wire  : !firrtl.uint<8>
     %invalid_ui8 = firrtl.invalidvalue : !firrtl.uint<8>
     firrtl.connect %inv, %invalid_ui8 : !firrtl.uint<8>, !firrtl.uint<8>
     %_T = firrtl.node %inv  : !firrtl.uint<8>
-    // CHECK: firrtl.regreset %clock
+    // CHECK: firrtl.reg %clock
     %r = firrtl.regreset %clock, %reset, %_T  : !firrtl.clock, !firrtl.uint<1>, !firrtl.uint<8>, !firrtl.uint<8>
     firrtl.connect %r, %d : !firrtl.uint<8>, !firrtl.uint<8>
     firrtl.connect %q, %r : !firrtl.uint<8>, !firrtl.uint<8>