From 0e31fd715e1ccd1cdf69263160ffcf6a053413b4 Mon Sep 17 00:00:00 2001 From: Nandor Licker Date: Tue, 21 Nov 2023 17:49:56 +0200 Subject: [PATCH] [InstanceGraph] Remove the module lookup helper (#6425) The instance graph is transitioning towards supporting multiple possible targets for instance operations. As a consequence, it can no longer assume that there are instance operations with a single target and it will expose a generic interface returning a list of targets. This PR removes the lookup helper from the graph and instead provides the users (`firrtl::InstanceOp` in particular) with helpers to fetch the unique referenced instance. --- .../Dialect/FIRRTL/FIRRTLDeclarations.td | 17 +++++++++++++- .../Dialect/FIRRTL/FIRRTLOpInterfaces.td | 2 +- include/circt/Dialect/FIRRTL/FIRRTLOps.h | 1 + include/circt/Support/InstanceGraph.h | 12 ++-------- lib/Conversion/FIRRTLToHW/LowerToHW.cpp | 15 +++++++------ lib/Dialect/FIRRTL/FIRRTLAnnotationHelper.cpp | 2 +- lib/Dialect/FIRRTL/FIRRTLOps.cpp | 2 +- lib/Dialect/FIRRTL/FIRRTLReductions.cpp | 17 +++++++------- .../FIRRTL/Transforms/AddSeqMemPorts.cpp | 2 +- .../FIRRTL/Transforms/CheckCombLoops.cpp | 3 +-- lib/Dialect/FIRRTL/Transforms/Dedup.cpp | 22 +++++++++---------- .../FIRRTL/Transforms/ExtractInstances.cpp | 2 +- lib/Dialect/FIRRTL/Transforms/IMConstProp.cpp | 9 ++++---- .../FIRRTL/Transforms/IMDeadCodeElim.cpp | 8 +++---- lib/Dialect/FIRRTL/Transforms/InferResets.cpp | 9 ++++---- lib/Dialect/FIRRTL/Transforms/InferWidths.cpp | 2 +- .../FIRRTL/Transforms/LowerAnnotations.cpp | 9 ++++---- .../FIRRTL/Transforms/LowerClasses.cpp | 17 ++++++-------- lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp | 2 +- lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp | 2 +- .../FIRRTL/Transforms/PrefixModules.cpp | 3 +-- .../FIRRTL/Transforms/ResolvePaths.cpp | 2 +- lib/Dialect/HW/HWReductions.cpp | 6 +++-- lib/Dialect/Ibis/Transforms/IbisTunneling.cpp | 4 ++-- lib/LogicalEquivalence/LogicExporter.cpp | 4 ++-- 25 files changed, 89 insertions(+), 85 deletions(-) diff --git a/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td b/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td index f3ed5f2d45fe..c26e1361a4e7 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td +++ b/include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td @@ -129,6 +129,21 @@ def InstanceOp : HardwareDeclOp<"instance", [ /// the new InstanceOp to the same location. InstanceOp cloneAndInsertPorts(ArrayRef> ports); + //===------------------------------------------------------------------===// + // Instance graph methods + //===------------------------------------------------------------------===// + + // Quick lookup of the referenced module using the instance graph. + template + T getReferencedModule(::circt::igraph::InstanceGraph &instanceGraph) { + auto moduleNameAttr = getModuleNameAttr().getAttr(); + auto *node = instanceGraph.lookup(moduleNameAttr); + if (!node) + return nullptr; + Operation *moduleOp = node->getModule(); + return dyn_cast_or_null(moduleOp); + } + //===------------------------------------------------------------------===// // PortList Methods //===------------------------------------------------------------------===// @@ -598,7 +613,7 @@ def ObjectOp : FIRRTLOp<"object", [ ParentOneOf<["firrtl::FModuleOp, firrtl::ClassOp"]>, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods - auto getReferencedModule(InstanceOpInterface op) { - return cast(getReferencedModuleImpl(op).getOperation()); - } - /// Check if child is instantiated by a parent. bool isAncestor(ModuleOpInterface child, ModuleOpInterface parent); @@ -330,8 +322,8 @@ static T &operator<<(T &os, const InstancePath &path) { return path.print(os); } -/// A data structure that caches and provides absolute paths to module instances -/// in the IR. +/// A data structure that caches and provides absolute paths to module +/// instances in the IR. struct InstancePathCache { /// The instance graph of the IR. InstanceGraph &instanceGraph; diff --git a/lib/Conversion/FIRRTLToHW/LowerToHW.cpp b/lib/Conversion/FIRRTLToHW/LowerToHW.cpp index 1d963f1157c3..22b1497a8d32 100644 --- a/lib/Conversion/FIRRTLToHW/LowerToHW.cpp +++ b/lib/Conversion/FIRRTLToHW/LowerToHW.cpp @@ -212,7 +212,7 @@ struct CircuitLoweringState { CircuitLoweringState(CircuitOp circuitOp, bool enableAnnotationWarning, bool emitChiselAssertsAsSVA, - InstanceGraph *instanceGraph, NLATable *nlaTable) + InstanceGraph &instanceGraph, NLATable *nlaTable) : circuitOp(circuitOp), instanceGraph(instanceGraph), enableAnnotationWarning(enableAnnotationWarning), emitChiselAssertsAsSVA(emitChiselAssertsAsSVA), nlaTable(nlaTable) { @@ -235,7 +235,7 @@ struct CircuitLoweringState { // Figure out which module is the DUT and TestHarness. If there is no // module marked as the DUT, the top module is the DUT. If the DUT and the // test harness are the same, then there is no test harness. - testHarness = instanceGraph->getTopLevelModule(); + testHarness = instanceGraph.getTopLevelModule(); if (!dut) { dut = testHarness; testHarness = nullptr; @@ -282,7 +282,7 @@ struct CircuitLoweringState { // Returns false if the module is not instantiated by the DUT. bool isInDUT(igraph::ModuleOpInterface child) { if (auto parent = dyn_cast(*dut)) - return getInstanceGraph()->isAncestor(child, parent); + return getInstanceGraph().isAncestor(child, parent); return dut == child; } @@ -293,7 +293,7 @@ struct CircuitLoweringState { // Harness is not known. bool isInTestHarness(igraph::ModuleOpInterface mod) { return !isInDUT(mod); } - InstanceGraph *getInstanceGraph() { return instanceGraph; } + InstanceGraph &getInstanceGraph() { return instanceGraph; } /// Given a type, return the corresponding lowered type for the HW dialect. /// A wrapper to the FIRRTLUtils::lowerType, required to ensure safe addition @@ -316,7 +316,7 @@ struct CircuitLoweringState { /// Cache of module symbols. We need to test hirarchy-based properties to /// lower annotaitons. - InstanceGraph *instanceGraph; + InstanceGraph &instanceGraph; // Record the set of remaining annotation classes. This is used to warn only // once about any annotation class. @@ -548,7 +548,7 @@ void FIRRTLModuleLowering::runOnOperation() { // if lowering failed. CircuitLoweringState state( circuit, enableAnnotationWarning, emitChiselAssertsAsSVA, - &getAnalysis(), &getAnalysis()); + getAnalysis(), &getAnalysis()); SmallVector modulesToProcess; @@ -3062,7 +3062,8 @@ LogicalResult FIRRTLLowering::visitDecl(MemOp op) { LogicalResult FIRRTLLowering::visitDecl(InstanceOp oldInstance) { Operation *oldModule = - circuitState.getInstanceGraph()->getReferencedModule(oldInstance); + oldInstance.getReferencedModule(circuitState.getInstanceGraph()); + auto newModule = circuitState.getNewModule(oldModule); if (!newModule) { oldInstance->emitOpError("could not find module [") diff --git a/lib/Dialect/FIRRTL/FIRRTLAnnotationHelper.cpp b/lib/Dialect/FIRRTL/FIRRTLAnnotationHelper.cpp index 319ae03e4332..25e9421d540f 100644 --- a/lib/Dialect/FIRRTL/FIRRTLAnnotationHelper.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLAnnotationHelper.cpp @@ -244,7 +244,7 @@ firrtl::resolveEntities(TokenAnnoTarget path, CircuitOp circuit, ArrayRef component(path.component); if (auto instance = dyn_cast(ref.getOp())) { instances.push_back(instance); - auto target = cast(instance.getReferencedModule(symTbl)); + auto target = cast(instance.getReferencedOperation(symTbl)); if (component.empty()) { ref = OpAnnoTarget(target); } else if (component.front().isIndex) { diff --git a/lib/Dialect/FIRRTL/FIRRTLOps.cpp b/lib/Dialect/FIRRTL/FIRRTLOps.cpp index fea64bcecd0b..8443dc0a881a 100644 --- a/lib/Dialect/FIRRTL/FIRRTLOps.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLOps.cpp @@ -2919,7 +2919,7 @@ ClassLike ObjectOp::getReferencedClass(const SymbolTable &symbolTable) { return symbolTable.lookup(symRef.getLeafReference()); } -Operation *ObjectOp::getReferencedModule(const SymbolTable &symtbl) { +Operation *ObjectOp::getReferencedOperation(const SymbolTable &symtbl) { return getReferencedClass(symtbl); } diff --git a/lib/Dialect/FIRRTL/FIRRTLReductions.cpp b/lib/Dialect/FIRRTL/FIRRTLReductions.cpp index 8839229632fa..736cf47a8b7e 100644 --- a/lib/Dialect/FIRRTL/FIRRTLReductions.cpp +++ b/lib/Dialect/FIRRTL/FIRRTLReductions.cpp @@ -67,7 +67,7 @@ findInstantiatedModule(firrtl::InstanceOp instOp, ::detail::SymbolCache &symbols) { auto *tableOp = SymbolTable::getNearestSymbolTable(instOp); auto moduleOp = dyn_cast( - instOp.getReferencedModule(symbols.getSymbolTable(tableOp))); + instOp.getReferencedOperation(symbols.getSymbolTable(tableOp))); return moduleOp ? std::optional(moduleOp) : std::nullopt; } @@ -338,7 +338,7 @@ struct InstanceStubber : public OpReduction { auto *tableOp = SymbolTable::getNearestSymbolTable(op); op->walk([&](firrtl::InstanceOp instOp) { auto moduleOp = cast( - instOp.getReferencedModule(symbols.getSymbolTable(tableOp))); + instOp.getReferencedOperation(symbols.getSymbolTable(tableOp))); deadInsts.insert(instOp); if (llvm::all_of( symbols.getSymbolUserMap(tableOp).getUsers(moduleOp), @@ -385,7 +385,7 @@ struct InstanceStubber : public OpReduction { } auto *tableOp = SymbolTable::getNearestSymbolTable(instOp); auto moduleOp = cast( - instOp.getReferencedModule(symbols.getSymbolTable(tableOp))); + instOp.getReferencedOperation(symbols.getSymbolTable(tableOp))); nlaRemover.markNLAsInOperation(instOp); erasedInsts.insert(instOp); if (llvm::all_of( @@ -658,12 +658,12 @@ struct ExtmoduleInstanceRemover : public OpReduction { uint64_t match(firrtl::InstanceOp instOp) override { return isa( - instOp.getReferencedModule(symbols.getNearestSymbolTable(instOp))); + instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp))); } LogicalResult rewrite(firrtl::InstanceOp instOp) override { auto portInfo = - cast( - instOp.getReferencedModule(symbols.getNearestSymbolTable(instOp))) + cast(instOp.getReferencedOperation( + symbols.getNearestSymbolTable(instOp))) .getPorts(); ImplicitLocOpBuilder builder(instOp.getLoc(), instOp); SmallVector replacementWires; @@ -895,7 +895,8 @@ struct EagerInliner : public OpReduction { uint64_t match(firrtl::InstanceOp instOp) override { auto *tableOp = SymbolTable::getNearestSymbolTable(instOp); - auto moduleOp = instOp.getReferencedModule(symbols.getSymbolTable(tableOp)); + auto *moduleOp = + instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)); if (!isa(moduleOp)) return 0; return symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1; @@ -921,7 +922,7 @@ struct EagerInliner : public OpReduction { } auto *tableOp = SymbolTable::getNearestSymbolTable(instOp); auto moduleOp = cast( - instOp.getReferencedModule(symbols.getSymbolTable(tableOp))); + instOp.getReferencedOperation(symbols.getSymbolTable(tableOp))); for (auto &op : llvm::make_early_inc_range(*moduleOp.getBodyBlock())) { op.remove(); builder.insert(&op); diff --git a/lib/Dialect/FIRRTL/Transforms/AddSeqMemPorts.cpp b/lib/Dialect/FIRRTL/Transforms/AddSeqMemPorts.cpp index 0442be7bd07a..44b4256c3b0f 100644 --- a/lib/Dialect/FIRRTL/Transforms/AddSeqMemPorts.cpp +++ b/lib/Dialect/FIRRTL/Transforms/AddSeqMemPorts.cpp @@ -194,7 +194,7 @@ LogicalResult AddSeqMemPortsPass::processModule(FModuleOp module, bool isDUT) { for (auto &op : llvm::make_early_inc_range(*module.getBodyBlock())) { if (auto inst = dyn_cast(op)) { - auto submodule = instanceGraph->getReferencedModule(inst); + auto submodule = inst.getReferencedModule(*instanceGraph); auto subMemInfoIt = memInfoMap.find(submodule); // If there are no extra ports, we don't have to do anything. diff --git a/lib/Dialect/FIRRTL/Transforms/CheckCombLoops.cpp b/lib/Dialect/FIRRTL/Transforms/CheckCombLoops.cpp index 45c132727080..2dd2dab3f7c0 100644 --- a/lib/Dialect/FIRRTL/Transforms/CheckCombLoops.cpp +++ b/lib/Dialect/FIRRTL/Transforms/CheckCombLoops.cpp @@ -376,8 +376,7 @@ class DiscoverLoops { if (auto inst = dyn_cast_or_null(ref.getDefiningOp())) { auto res = cast(ref.getValue()); auto portNum = res.getResultNumber(); - auto refMod = - dyn_cast_or_null(*instanceGraph.getReferencedModule(inst)); + auto refMod = inst.getReferencedModule(instanceGraph); if (!refMod) return; FieldRef modArg(refMod.getArgument(portNum), ref.getFieldID()); diff --git a/lib/Dialect/FIRRTL/Transforms/Dedup.cpp b/lib/Dialect/FIRRTL/Transforms/Dedup.cpp index a57535c59bca..dc244a26dd3b 100644 --- a/lib/Dialect/FIRRTL/Transforms/Dedup.cpp +++ b/lib/Dialect/FIRRTL/Transforms/Dedup.cpp @@ -620,20 +620,20 @@ struct Equivalence { LogicalResult check(InFlightDiagnostic &diag, InstanceOp a, InstanceOp b) { auto aName = a.getModuleNameAttr().getAttr(); auto bName = b.getModuleNameAttr().getAttr(); + if (aName == bName) + return success(); + // If the modules instantiate are different we will want to know why the // sub module did not dedupliate. This code recursively checks the child // module. - if (aName != bName) { - auto aModule = instanceGraph.getReferencedModule(a); - auto bModule = instanceGraph.getReferencedModule(b); - // Create a new error for the submodule. - diag.attachNote(std::nullopt) - << "in instance " << a.getNameAttr() << " of " << aName - << ", and instance " << b.getNameAttr() << " of " << bName; - check(diag, aModule, bModule); - return failure(); - } - return success(); + auto aModule = a.getReferencedModule(instanceGraph); + auto bModule = b.getReferencedModule(instanceGraph); + // Create a new error for the submodule. + diag.attachNote(std::nullopt) + << "in instance " << a.getNameAttr() << " of " << aName + << ", and instance " << b.getNameAttr() << " of " << bName; + check(diag, aModule, bModule); + return failure(); } // NOLINTNEXTLINE(misc-no-recursion) diff --git a/lib/Dialect/FIRRTL/Transforms/ExtractInstances.cpp b/lib/Dialect/FIRRTL/Transforms/ExtractInstances.cpp index 27963d7f65f9..ce37a38f3f07 100644 --- a/lib/Dialect/FIRRTL/Transforms/ExtractInstances.cpp +++ b/lib/Dialect/FIRRTL/Transforms/ExtractInstances.cpp @@ -278,7 +278,7 @@ void ExtractInstancesPass::collectAnnos() { // Gather the annotations on instances to be extracted. circuit.walk([&](InstanceOp inst) { SmallVector instAnnos; - Operation *module = instanceGraph->getReferencedModule(inst); + Operation *module = inst.getReferencedModule(*instanceGraph); // Module-level annotations. auto it = annotatedModules.find(module); diff --git a/lib/Dialect/FIRRTL/Transforms/IMConstProp.cpp b/lib/Dialect/FIRRTL/Transforms/IMConstProp.cpp index d4d33d27d89b..bc19ce417e1d 100644 --- a/lib/Dialect/FIRRTL/Transforms/IMConstProp.cpp +++ b/lib/Dialect/FIRRTL/Transforms/IMConstProp.cpp @@ -565,7 +565,7 @@ void IMConstPropPass::markInvalidValueOp(InvalidValueOp invalid) { /// enclosing block is marked live. This sets up the def-use edges for ports. void IMConstPropPass::markInstanceOp(InstanceOp instance) { // Get the module being reference or a null pointer if this is an extmodule. - Operation *op = instanceGraph->getReferencedModule(instance); + Operation *op = instance.getReferencedModule(*instanceGraph); // If this is an extmodule, just remember that any results and inouts are // overdefined. @@ -721,12 +721,11 @@ void IMConstPropPass::visitConnectLike(FConnectLike connect, if (auto instance = dest.getDefiningOp()) { // Update the dest, when its an instance op. mergeLatticeValue(fieldRefDestConnected, srcValue); - auto module = - dyn_cast(*instanceGraph->getReferencedModule(instance)); - if (!module) + auto mod = instance.getReferencedModule(*instanceGraph); + if (!mod) return; - BlockArgument modulePortVal = module.getArgument(dest.getResultNumber()); + BlockArgument modulePortVal = mod.getArgument(dest.getResultNumber()); return mergeLatticeValue( FieldRef(modulePortVal, fieldRefDestConnected.getFieldID()), diff --git a/lib/Dialect/FIRRTL/Transforms/IMDeadCodeElim.cpp b/lib/Dialect/FIRRTL/Transforms/IMDeadCodeElim.cpp index 7a986157fb16..59bc073bea2c 100644 --- a/lib/Dialect/FIRRTL/Transforms/IMDeadCodeElim.cpp +++ b/lib/Dialect/FIRRTL/Transforms/IMDeadCodeElim.cpp @@ -128,8 +128,7 @@ struct IMDeadCodeElimPass : public IMDeadCodeElimBase { void IMDeadCodeElimPass::visitInstanceOp(InstanceOp instance) { markBlockUndeletable(instance); - auto module = - dyn_cast(*instanceGraph->getReferencedModule(instance)); + auto module = instance.getReferencedModule(*instanceGraph); if (!module) return; @@ -198,7 +197,7 @@ void IMDeadCodeElimPass::visitUser(Operation *op) { void IMDeadCodeElimPass::markInstanceOp(InstanceOp instance) { // Get the module being referenced. - Operation *op = instanceGraph->getReferencedModule(instance); + Operation *op = instance.getReferencedModule(*instanceGraph); // If this is an extmodule, just remember that any inputs and inouts are // alive. @@ -487,8 +486,7 @@ void IMDeadCodeElimPass::visitValue(Value value) { if (auto instance = value.getDefiningOp()) { auto instanceResult = value.cast(); // Update the src, when it's an instance op. - auto module = - dyn_cast(*instanceGraph->getReferencedModule(instance)); + auto module = instance.getReferencedModule(*instanceGraph); // Propagate liveness only when a port is output. if (!module || module.getPortDirection(instanceResult.getResultNumber()) == diff --git a/lib/Dialect/FIRRTL/Transforms/InferResets.cpp b/lib/Dialect/FIRRTL/Transforms/InferResets.cpp index 4bd1df91df56..4b84e468ebaf 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferResets.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferResets.cpp @@ -862,7 +862,7 @@ void InferResetsPass::traceResets(CircuitOp circuit) { /// instance's port values with the target module's port values. void InferResetsPass::traceResets(InstanceOp inst) { // Lookup the referenced module. Nothing to do if its an extmodule. - auto module = dyn_cast(*instanceGraph->getReferencedModule(inst)); + auto module = inst.getReferencedModule(*instanceGraph); if (!module) return; LLVM_DEBUG(llvm::dbgs() << "Visiting instance " << inst.getName() << "\n"); @@ -1114,8 +1114,8 @@ LogicalResult InferResetsPass::updateReset(ResetNetwork net, ResetKind kind) { if (auto blockArg = dyn_cast(value)) moduleWorklist.insert(blockArg.getOwner()->getParentOp()); else if (auto instOp = value.getDefiningOp()) { - if (auto extmodule = dyn_cast( - *instanceGraph->getReferencedModule(instOp))) + if (auto extmodule = + instOp.getReferencedModule(*instanceGraph)) extmoduleWorklist.insert({extmodule, instOp}); } else if (auto uncast = value.getDefiningOp()) { uncast.replaceAllUsesWith(uncast.getInput()); @@ -1742,8 +1742,7 @@ void InferResetsPass::implementAsyncReset(Operation *op, FModuleOp module, // Lookup the reset domain of the instantiated module. If there is no // reset domain associated with that module, or the module is explicitly // marked as being in no domain, simply skip. - auto refModule = - dyn_cast(*instanceGraph->getReferencedModule(instOp)); + auto refModule = instOp.getReferencedModule(*instanceGraph); if (!refModule) return; auto domainIt = domains.find(refModule); diff --git a/lib/Dialect/FIRRTL/Transforms/InferWidths.cpp b/lib/Dialect/FIRRTL/Transforms/InferWidths.cpp index 3e92fbff9d1b..004548986223 100644 --- a/lib/Dialect/FIRRTL/Transforms/InferWidths.cpp +++ b/lib/Dialect/FIRRTL/Transforms/InferWidths.cpp @@ -1635,7 +1635,7 @@ LogicalResult InferenceMapping::mapOperation(Operation *op) { // Handle instances of other modules. .Case([&](auto op) { - auto refdModule = op.getReferencedModule(symtbl); + auto refdModule = op.getReferencedOperation(symtbl); auto module = dyn_cast(&*refdModule); if (!module) { auto diag = mlir::emitError(op.getLoc()); diff --git a/lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp b/lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp index 72befb2e1d5b..94bfa43eadb8 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp @@ -808,8 +808,8 @@ LogicalResult LowerAnnotationsPass::solveWiringProblems(ApplyState &state) { while (!sources.empty() && !sinks.empty()) { if (sources.top() != sinks.top()) break; - auto newLCA = sources.top(); - lca = cast(instanceGraph.getReferencedModule(newLCA)); + auto newLCA = cast(*sources.top()); + lca = cast(newLCA.getReferencedModule(instanceGraph)); sources = sources.dropFront(); sinks = sinks.dropFront(); } @@ -885,8 +885,9 @@ LogicalResult LowerAnnotationsPass::solveWiringProblems(ApplyState &state) { auto addPorts = [&](igraph::InstancePath insts, Value val, Type tpe, Direction dir) { StringRef name, instName; - for (auto inst : llvm::reverse(insts)) { - auto mod = instanceGraph.getReferencedModule(inst); + for (auto instNode : llvm::reverse(insts)) { + auto inst = cast(*instNode); + auto mod = inst.getReferencedModule(instanceGraph); if (name.empty()) { if (problem.newNameHint.empty()) name = state.getNamespace(mod).newName( diff --git a/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp b/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp index 2c441f9e5315..e70babee765c 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp @@ -100,7 +100,6 @@ struct LowerClassesPass : public LowerClassesBase { // Update Object instantiations in a FIRRTL Module or OM Class. LogicalResult updateInstances(Operation *op, InstanceGraph &instanceGraph, - const SymbolTable &symbolTable, const LoweringState &state); // Convert to OM ops and types in Classes or Modules. @@ -384,8 +383,7 @@ void LowerClassesPass::runOnOperation() { // Update Object creation ops in Classes or Modules in parallel. if (failed( mlir::failableParallelForEach(ctx, objectContainers, [&](auto *op) { - return updateInstances(op, instanceGraph, symbolTable, - loweringState); + return updateInstances(op, instanceGraph, loweringState); }))) return signalPassFailure(); @@ -686,7 +684,7 @@ updateObjectInClass(firrtl::ObjectOp firrtlObject, // Module. static LogicalResult updateInstanceInClass(InstanceOp firrtlInstance, hw::HierPathOp hierPath, - const SymbolTable &symbolTable, + InstanceGraph &instanceGraph, SmallVectorImpl &opsToErase) { // Set the insertion point right before the instance op. @@ -726,7 +724,7 @@ updateInstanceInClass(InstanceOp firrtlInstance, hw::HierPathOp hierPath, // Get the referenced module to get its name. auto referencedModule = - dyn_cast(firrtlInstance.getReferencedModule(symbolTable)); + firrtlInstance.getReferencedModule(instanceGraph); StringRef moduleName = referencedModule.getName(); @@ -821,7 +819,7 @@ updateInstancesInModule(FModuleOp moduleOp, InstanceGraph &instanceGraph, } static LogicalResult updateObjectsAndInstancesInClass( - om::ClassOp classOp, const SymbolTable &symbolTable, + om::ClassOp classOp, InstanceGraph &instanceGraph, const LoweringState &state, SmallVectorImpl &opsToErase) { OpBuilder builder(classOp); auto &classState = state.classLoweringStateTable.at(classOp); @@ -831,7 +829,7 @@ static LogicalResult updateObjectsAndInstancesInClass( if (failed(updateObjectInClass(objectOp, opsToErase))) return failure(); } else if (auto instanceOp = dyn_cast(op)) { - if (failed(updateInstanceInClass(instanceOp, *it++, symbolTable, + if (failed(updateInstanceInClass(instanceOp, *it++, instanceGraph, opsToErase))) return failure(); } @@ -842,7 +840,6 @@ static LogicalResult updateObjectsAndInstancesInClass( // Update Object or Module instantiations in a FIRRTL Module or OM Class. LogicalResult LowerClassesPass::updateInstances(Operation *op, InstanceGraph &instanceGraph, - const SymbolTable &symbolTable, const LoweringState &state) { // Track ops to erase at the end. We can't do this eagerly, since we want to @@ -860,8 +857,8 @@ LogicalResult LowerClassesPass::updateInstances(Operation *op, .Case([&](om::ClassOp classOp) { // Convert FIRRTL Module instance within a Class to OM // Object instance. - return updateObjectsAndInstancesInClass(classOp, symbolTable, state, - opsToErase); + return updateObjectsAndInstancesInClass(classOp, instanceGraph, + state, opsToErase); }) .Default([](auto *op) { return success(); }); if (failed(result)) diff --git a/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp b/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp index 7d24c59ba524..ba034889444b 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp @@ -1444,7 +1444,7 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) { SmallVector newNames; SmallVector newPortAnno; PreserveAggregate::PreserveMode mode = getPreservationModeForModule( - cast(op.getReferencedModule(symTbl))); + cast(op.getReferencedOperation(symTbl))); endFields.push_back(0); for (size_t i = 0, e = op.getNumResults(); i != e; ++i) { diff --git a/lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp b/lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp index 1698b4db4213..b9cfe5dbe275 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp @@ -517,7 +517,7 @@ class LowerXMRPass : public LowerXMRBase { // Propagate the reachable RefSendOp across modules. LogicalResult handleInstanceOp(InstanceOp inst, InstanceGraph &instanceGraph) { - Operation *mod = instanceGraph.getReferencedModule(inst); + Operation *mod = inst.getReferencedModule(instanceGraph); if (auto extRefMod = dyn_cast(mod)) { // Extern modules can generate RefType ports, they have an attached // attribute which specifies the internal path into the extern module. diff --git a/lib/Dialect/FIRRTL/Transforms/PrefixModules.cpp b/lib/Dialect/FIRRTL/Transforms/PrefixModules.cpp index 49333cda2ae6..ebc75e67890f 100644 --- a/lib/Dialect/FIRRTL/Transforms/PrefixModules.cpp +++ b/lib/Dialect/FIRRTL/Transforms/PrefixModules.cpp @@ -205,8 +205,7 @@ void PrefixModulesPass::renameModuleBody(std::string prefix, StringRef oldName, newPrefix = StringAttr::get(context, prefix); memOp->setAttr("prefix", newPrefix); } else if (auto instanceOp = dyn_cast(op)) { - auto target = dyn_cast( - *instanceGraph->getReferencedModule(instanceOp)); + auto target = instanceOp.getReferencedModule(*instanceGraph); // Skip all external modules, unless one of the following conditions // is true: diff --git a/lib/Dialect/FIRRTL/Transforms/ResolvePaths.cpp b/lib/Dialect/FIRRTL/Transforms/ResolvePaths.cpp index 99598e888838..2566a6d139d7 100644 --- a/lib/Dialect/FIRRTL/Transforms/ResolvePaths.cpp +++ b/lib/Dialect/FIRRTL/Transforms/ResolvePaths.cpp @@ -64,7 +64,7 @@ struct PathResolver { // through the list of instances looking for the first module which is // multiply instantiated. We will start our HierPathOp at this instance. auto *it = llvm::find_if(target.instances, [&](InstanceOp instance) { - auto *node = instanceGraph[instanceGraph.getReferencedModule(instance)]; + auto *node = instanceGraph.lookup(instance.getReferencedModuleNameAttr()); return !node->hasOneUse(); }); diff --git a/lib/Dialect/HW/HWReductions.cpp b/lib/Dialect/HW/HWReductions.cpp index 63fd703ba0b6..1051cf9c9416 100644 --- a/lib/Dialect/HW/HWReductions.cpp +++ b/lib/Dialect/HW/HWReductions.cpp @@ -34,10 +34,12 @@ struct ModuleSizeCache { uint64_t size = 1; module->walk([&](Operation *op) { size += 1; - if (auto instOp = dyn_cast(op)) + if (auto instOp = dyn_cast(op)) { + auto *node = instanceGraph.lookup(instOp.getReferencedModuleNameAttr()); if (auto instModule = - instanceGraph.getReferencedModule(instOp)) + dyn_cast_or_null(*node->getModule())) size += getModuleSize(instModule, instanceGraph); + } }); moduleSizes.insert({module, size}); return size; diff --git a/lib/Dialect/Ibis/Transforms/IbisTunneling.cpp b/lib/Dialect/Ibis/Transforms/IbisTunneling.cpp index f3e51d408d82..d8ffaa56d526 100644 --- a/lib/Dialect/Ibis/Transforms/IbisTunneling.cpp +++ b/lib/Dialect/Ibis/Transforms/IbisTunneling.cpp @@ -287,8 +287,8 @@ LogicalResult Tunneler::tunnelDown(InstanceGraphNode *currentContainer, // We're not in the target, but tunneling into a child instance. // Create output ports in the child instance for the requested ports. - auto *tunnelScopeNode = ig.lookup(ig.getReferencedModule( - cast(tunnelInstance.getOperation()))); + auto *tunnelScopeNode = + ig.lookup(tunnelInstance.getReferencedModuleNameAttr()); auto tunnelScope = tunnelScopeNode->getModule(); rewriter.setInsertionPointToEnd(tunnelScope.getBodyBlock()); diff --git a/lib/LogicalEquivalence/LogicExporter.cpp b/lib/LogicalEquivalence/LogicExporter.cpp index 83b50d93aa41..0c0689dcc0d1 100644 --- a/lib/LogicalEquivalence/LogicExporter.cpp +++ b/lib/LogicalEquivalence/LogicExporter.cpp @@ -74,8 +74,8 @@ struct Visitor : public hw::StmtVisitor, //===--------------------------------------------------------------------===// LogicalResult visitStmt(hw::InstanceOp op) { - if (auto hwModule = - llvm::dyn_cast(op.getReferencedModuleSlow())) { + if (auto hwModule = llvm::dyn_cast( + op.getReferencedModuleCached(/*cache=*/nullptr))) { circuit->addInstance(op.getInstanceName(), hwModule, op->getOperands(), op->getResults()); return success();