Skip to content

Commit

Permalink
[InstanceGraph] Remove the module lookup helper (#6425)
Browse files Browse the repository at this point in the history
The instance graph is transitioning towards supporting multiple possible targets for instance operations.
As a consequence, it can no longer assume that there are instance operations with a single target and it will expose a generic interface returning a list of targets.

This PR removes the lookup helper from the graph and instead provides the users (`firrtl::InstanceOp` in particular) with helpers to fetch the unique referenced instance.
  • Loading branch information
nandor authored Nov 21, 2023
1 parent 2711dd5 commit 0e31fd7
Show file tree
Hide file tree
Showing 25 changed files with 89 additions and 85 deletions.
17 changes: 16 additions & 1 deletion include/circt/Dialect/FIRRTL/FIRRTLDeclarations.td
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,21 @@ def InstanceOp : HardwareDeclOp<"instance", [
/// the new InstanceOp to the same location.
InstanceOp cloneAndInsertPorts(ArrayRef<std::pair<unsigned, PortInfo>> ports);

//===------------------------------------------------------------------===//
// Instance graph methods
//===------------------------------------------------------------------===//

// Quick lookup of the referenced module using the instance graph.
template <typename T = ::circt::igraph::ModuleOpInterface>
T getReferencedModule(::circt::igraph::InstanceGraph &instanceGraph) {
auto moduleNameAttr = getModuleNameAttr().getAttr();
auto *node = instanceGraph.lookup(moduleNameAttr);
if (!node)
return nullptr;
Operation *moduleOp = node->getModule();
return dyn_cast_or_null<T>(moduleOp);
}

//===------------------------------------------------------------------===//
// PortList Methods
//===------------------------------------------------------------------===//
Expand Down Expand Up @@ -598,7 +613,7 @@ def ObjectOp : FIRRTLOp<"object", [
ParentOneOf<["firrtl::FModuleOp, firrtl::ClassOp"]>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<FInstanceLike, [
"getReferencedModule"
"getReferencedOperation"
]>,
DeclareOpInterfaceMethods<InstanceGraphInstanceOpInterface, [
"getReferencedModuleName",
Expand Down
2 changes: 1 addition & 1 deletion include/circt/Dialect/FIRRTL/FIRRTLOpInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def FInstanceLike : OpInterface<"FInstanceLike", [

let methods = [
InterfaceMethod<"Get the referenced module via a symbol table.",
"::mlir::Operation *", "getReferencedModule",
"::mlir::Operation *", "getReferencedOperation",
(ins "const SymbolTable&":$symtbl),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
Expand Down
1 change: 1 addition & 0 deletions include/circt/Dialect/FIRRTL/FIRRTLOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "circt/Dialect/HW/InnerSymbolTable.h"
#include "circt/Dialect/Seq/SeqAttributes.h"
#include "circt/Support/FieldRef.h"
#include "circt/Support/InstanceGraph.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/RegionKindInterface.h"
Expand Down
12 changes: 2 additions & 10 deletions include/circt/Support/InstanceGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,6 @@ class InstanceGraph {
/// Lookup an InstanceGraphNode for a module.
InstanceGraphNode *operator[](ModuleOpInterface op) { return lookup(op); }

/// Look up the referenced module from an InstanceOp. This will use a
/// hashtable lookup to find the module, where
/// InstanceOp.getReferencedModule() will be a linear search through the IR.
template <typename TTarget = ModuleOpInterface>
auto getReferencedModule(InstanceOpInterface op) {
return cast<TTarget>(getReferencedModuleImpl(op).getOperation());
}

/// Check if child is instantiated by a parent.
bool isAncestor(ModuleOpInterface child, ModuleOpInterface parent);

Expand Down Expand Up @@ -330,8 +322,8 @@ static T &operator<<(T &os, const InstancePath &path) {
return path.print(os);
}

/// A data structure that caches and provides absolute paths to module instances
/// in the IR.
/// A data structure that caches and provides absolute paths to module
/// instances in the IR.
struct InstancePathCache {
/// The instance graph of the IR.
InstanceGraph &instanceGraph;
Expand Down
15 changes: 8 additions & 7 deletions lib/Conversion/FIRRTLToHW/LowerToHW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ struct CircuitLoweringState {

CircuitLoweringState(CircuitOp circuitOp, bool enableAnnotationWarning,
bool emitChiselAssertsAsSVA,
InstanceGraph *instanceGraph, NLATable *nlaTable)
InstanceGraph &instanceGraph, NLATable *nlaTable)
: circuitOp(circuitOp), instanceGraph(instanceGraph),
enableAnnotationWarning(enableAnnotationWarning),
emitChiselAssertsAsSVA(emitChiselAssertsAsSVA), nlaTable(nlaTable) {
Expand All @@ -235,7 +235,7 @@ struct CircuitLoweringState {
// Figure out which module is the DUT and TestHarness. If there is no
// module marked as the DUT, the top module is the DUT. If the DUT and the
// test harness are the same, then there is no test harness.
testHarness = instanceGraph->getTopLevelModule();
testHarness = instanceGraph.getTopLevelModule();
if (!dut) {
dut = testHarness;
testHarness = nullptr;
Expand Down Expand Up @@ -282,7 +282,7 @@ struct CircuitLoweringState {
// Returns false if the module is not instantiated by the DUT.
bool isInDUT(igraph::ModuleOpInterface child) {
if (auto parent = dyn_cast<igraph::ModuleOpInterface>(*dut))
return getInstanceGraph()->isAncestor(child, parent);
return getInstanceGraph().isAncestor(child, parent);
return dut == child;
}

Expand All @@ -293,7 +293,7 @@ struct CircuitLoweringState {
// Harness is not known.
bool isInTestHarness(igraph::ModuleOpInterface mod) { return !isInDUT(mod); }

InstanceGraph *getInstanceGraph() { return instanceGraph; }
InstanceGraph &getInstanceGraph() { return instanceGraph; }

/// Given a type, return the corresponding lowered type for the HW dialect.
/// A wrapper to the FIRRTLUtils::lowerType, required to ensure safe addition
Expand All @@ -316,7 +316,7 @@ struct CircuitLoweringState {

/// Cache of module symbols. We need to test hirarchy-based properties to
/// lower annotaitons.
InstanceGraph *instanceGraph;
InstanceGraph &instanceGraph;

// Record the set of remaining annotation classes. This is used to warn only
// once about any annotation class.
Expand Down Expand Up @@ -548,7 +548,7 @@ void FIRRTLModuleLowering::runOnOperation() {
// if lowering failed.
CircuitLoweringState state(
circuit, enableAnnotationWarning, emitChiselAssertsAsSVA,
&getAnalysis<InstanceGraph>(), &getAnalysis<NLATable>());
getAnalysis<InstanceGraph>(), &getAnalysis<NLATable>());

SmallVector<hw::HWModuleOp, 32> modulesToProcess;

Expand Down Expand Up @@ -3062,7 +3062,8 @@ LogicalResult FIRRTLLowering::visitDecl(MemOp op) {

LogicalResult FIRRTLLowering::visitDecl(InstanceOp oldInstance) {
Operation *oldModule =
circuitState.getInstanceGraph()->getReferencedModule(oldInstance);
oldInstance.getReferencedModule(circuitState.getInstanceGraph());

auto newModule = circuitState.getNewModule(oldModule);
if (!newModule) {
oldInstance->emitOpError("could not find module [")
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/FIRRTL/FIRRTLAnnotationHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ firrtl::resolveEntities(TokenAnnoTarget path, CircuitOp circuit,
ArrayRef<TargetToken> component(path.component);
if (auto instance = dyn_cast<InstanceOp>(ref.getOp())) {
instances.push_back(instance);
auto target = cast<FModuleLike>(instance.getReferencedModule(symTbl));
auto target = cast<FModuleLike>(instance.getReferencedOperation(symTbl));
if (component.empty()) {
ref = OpAnnoTarget(target);
} else if (component.front().isIndex) {
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/FIRRTL/FIRRTLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2919,7 +2919,7 @@ ClassLike ObjectOp::getReferencedClass(const SymbolTable &symbolTable) {
return symbolTable.lookup<ClassLike>(symRef.getLeafReference());
}

Operation *ObjectOp::getReferencedModule(const SymbolTable &symtbl) {
Operation *ObjectOp::getReferencedOperation(const SymbolTable &symtbl) {
return getReferencedClass(symtbl);
}

Expand Down
17 changes: 9 additions & 8 deletions lib/Dialect/FIRRTL/FIRRTLReductions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ findInstantiatedModule(firrtl::InstanceOp instOp,
::detail::SymbolCache &symbols) {
auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
auto moduleOp = dyn_cast<firrtl::FModuleOp>(
instOp.getReferencedModule(symbols.getSymbolTable(tableOp)));
instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
return moduleOp ? std::optional(moduleOp) : std::nullopt;
}

Expand Down Expand Up @@ -338,7 +338,7 @@ struct InstanceStubber : public OpReduction<firrtl::InstanceOp> {
auto *tableOp = SymbolTable::getNearestSymbolTable(op);
op->walk([&](firrtl::InstanceOp instOp) {
auto moduleOp = cast<firrtl::FModuleLike>(
instOp.getReferencedModule(symbols.getSymbolTable(tableOp)));
instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
deadInsts.insert(instOp);
if (llvm::all_of(
symbols.getSymbolUserMap(tableOp).getUsers(moduleOp),
Expand Down Expand Up @@ -385,7 +385,7 @@ struct InstanceStubber : public OpReduction<firrtl::InstanceOp> {
}
auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
auto moduleOp = cast<firrtl::FModuleLike>(
instOp.getReferencedModule(symbols.getSymbolTable(tableOp)));
instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
nlaRemover.markNLAsInOperation(instOp);
erasedInsts.insert(instOp);
if (llvm::all_of(
Expand Down Expand Up @@ -658,12 +658,12 @@ struct ExtmoduleInstanceRemover : public OpReduction<firrtl::InstanceOp> {

uint64_t match(firrtl::InstanceOp instOp) override {
return isa<firrtl::FExtModuleOp>(
instOp.getReferencedModule(symbols.getNearestSymbolTable(instOp)));
instOp.getReferencedOperation(symbols.getNearestSymbolTable(instOp)));
}
LogicalResult rewrite(firrtl::InstanceOp instOp) override {
auto portInfo =
cast<firrtl::FModuleLike>(
instOp.getReferencedModule(symbols.getNearestSymbolTable(instOp)))
cast<firrtl::FModuleLike>(instOp.getReferencedOperation(
symbols.getNearestSymbolTable(instOp)))
.getPorts();
ImplicitLocOpBuilder builder(instOp.getLoc(), instOp);
SmallVector<Value> replacementWires;
Expand Down Expand Up @@ -895,7 +895,8 @@ struct EagerInliner : public OpReduction<firrtl::InstanceOp> {

uint64_t match(firrtl::InstanceOp instOp) override {
auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
auto moduleOp = instOp.getReferencedModule(symbols.getSymbolTable(tableOp));
auto *moduleOp =
instOp.getReferencedOperation(symbols.getSymbolTable(tableOp));
if (!isa<firrtl::FModuleOp>(moduleOp))
return 0;
return symbols.getSymbolUserMap(tableOp).getUsers(moduleOp).size() == 1;
Expand All @@ -921,7 +922,7 @@ struct EagerInliner : public OpReduction<firrtl::InstanceOp> {
}
auto *tableOp = SymbolTable::getNearestSymbolTable(instOp);
auto moduleOp = cast<firrtl::FModuleOp>(
instOp.getReferencedModule(symbols.getSymbolTable(tableOp)));
instOp.getReferencedOperation(symbols.getSymbolTable(tableOp)));
for (auto &op : llvm::make_early_inc_range(*moduleOp.getBodyBlock())) {
op.remove();
builder.insert(&op);
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/FIRRTL/Transforms/AddSeqMemPorts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ LogicalResult AddSeqMemPortsPass::processModule(FModuleOp module, bool isDUT) {

for (auto &op : llvm::make_early_inc_range(*module.getBodyBlock())) {
if (auto inst = dyn_cast<InstanceOp>(op)) {
auto submodule = instanceGraph->getReferencedModule(inst);
auto submodule = inst.getReferencedModule(*instanceGraph);

auto subMemInfoIt = memInfoMap.find(submodule);
// If there are no extra ports, we don't have to do anything.
Expand Down
3 changes: 1 addition & 2 deletions lib/Dialect/FIRRTL/Transforms/CheckCombLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,7 @@ class DiscoverLoops {
if (auto inst = dyn_cast_or_null<InstanceOp>(ref.getDefiningOp())) {
auto res = cast<OpResult>(ref.getValue());
auto portNum = res.getResultNumber();
auto refMod =
dyn_cast_or_null<FModuleOp>(*instanceGraph.getReferencedModule(inst));
auto refMod = inst.getReferencedModule<FModuleOp>(instanceGraph);
if (!refMod)
return;
FieldRef modArg(refMod.getArgument(portNum), ref.getFieldID());
Expand Down
22 changes: 11 additions & 11 deletions lib/Dialect/FIRRTL/Transforms/Dedup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,20 +620,20 @@ struct Equivalence {
LogicalResult check(InFlightDiagnostic &diag, InstanceOp a, InstanceOp b) {
auto aName = a.getModuleNameAttr().getAttr();
auto bName = b.getModuleNameAttr().getAttr();
if (aName == bName)
return success();

// If the modules instantiate are different we will want to know why the
// sub module did not dedupliate. This code recursively checks the child
// module.
if (aName != bName) {
auto aModule = instanceGraph.getReferencedModule(a);
auto bModule = instanceGraph.getReferencedModule(b);
// Create a new error for the submodule.
diag.attachNote(std::nullopt)
<< "in instance " << a.getNameAttr() << " of " << aName
<< ", and instance " << b.getNameAttr() << " of " << bName;
check(diag, aModule, bModule);
return failure();
}
return success();
auto aModule = a.getReferencedModule(instanceGraph);
auto bModule = b.getReferencedModule(instanceGraph);
// Create a new error for the submodule.
diag.attachNote(std::nullopt)
<< "in instance " << a.getNameAttr() << " of " << aName
<< ", and instance " << b.getNameAttr() << " of " << bName;
check(diag, aModule, bModule);
return failure();
}

// NOLINTNEXTLINE(misc-no-recursion)
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/FIRRTL/Transforms/ExtractInstances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ void ExtractInstancesPass::collectAnnos() {
// Gather the annotations on instances to be extracted.
circuit.walk([&](InstanceOp inst) {
SmallVector<Annotation, 1> instAnnos;
Operation *module = instanceGraph->getReferencedModule(inst);
Operation *module = inst.getReferencedModule(*instanceGraph);

// Module-level annotations.
auto it = annotatedModules.find(module);
Expand Down
9 changes: 4 additions & 5 deletions lib/Dialect/FIRRTL/Transforms/IMConstProp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ void IMConstPropPass::markInvalidValueOp(InvalidValueOp invalid) {
/// enclosing block is marked live. This sets up the def-use edges for ports.
void IMConstPropPass::markInstanceOp(InstanceOp instance) {
// Get the module being reference or a null pointer if this is an extmodule.
Operation *op = instanceGraph->getReferencedModule(instance);
Operation *op = instance.getReferencedModule(*instanceGraph);

// If this is an extmodule, just remember that any results and inouts are
// overdefined.
Expand Down Expand Up @@ -721,12 +721,11 @@ void IMConstPropPass::visitConnectLike(FConnectLike connect,
if (auto instance = dest.getDefiningOp<InstanceOp>()) {
// Update the dest, when its an instance op.
mergeLatticeValue(fieldRefDestConnected, srcValue);
auto module =
dyn_cast<FModuleOp>(*instanceGraph->getReferencedModule(instance));
if (!module)
auto mod = instance.getReferencedModule<FModuleOp>(*instanceGraph);
if (!mod)
return;

BlockArgument modulePortVal = module.getArgument(dest.getResultNumber());
BlockArgument modulePortVal = mod.getArgument(dest.getResultNumber());

return mergeLatticeValue(
FieldRef(modulePortVal, fieldRefDestConnected.getFieldID()),
Expand Down
8 changes: 3 additions & 5 deletions lib/Dialect/FIRRTL/Transforms/IMDeadCodeElim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,7 @@ struct IMDeadCodeElimPass : public IMDeadCodeElimBase<IMDeadCodeElimPass> {
void IMDeadCodeElimPass::visitInstanceOp(InstanceOp instance) {
markBlockUndeletable(instance);

auto module =
dyn_cast<FModuleOp>(*instanceGraph->getReferencedModule(instance));
auto module = instance.getReferencedModule<FModuleOp>(*instanceGraph);

if (!module)
return;
Expand Down Expand Up @@ -198,7 +197,7 @@ void IMDeadCodeElimPass::visitUser(Operation *op) {

void IMDeadCodeElimPass::markInstanceOp(InstanceOp instance) {
// Get the module being referenced.
Operation *op = instanceGraph->getReferencedModule(instance);
Operation *op = instance.getReferencedModule(*instanceGraph);

// If this is an extmodule, just remember that any inputs and inouts are
// alive.
Expand Down Expand Up @@ -487,8 +486,7 @@ void IMDeadCodeElimPass::visitValue(Value value) {
if (auto instance = value.getDefiningOp<InstanceOp>()) {
auto instanceResult = value.cast<mlir::OpResult>();
// Update the src, when it's an instance op.
auto module =
dyn_cast<FModuleOp>(*instanceGraph->getReferencedModule(instance));
auto module = instance.getReferencedModule<FModuleOp>(*instanceGraph);

// Propagate liveness only when a port is output.
if (!module || module.getPortDirection(instanceResult.getResultNumber()) ==
Expand Down
9 changes: 4 additions & 5 deletions lib/Dialect/FIRRTL/Transforms/InferResets.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ void InferResetsPass::traceResets(CircuitOp circuit) {
/// instance's port values with the target module's port values.
void InferResetsPass::traceResets(InstanceOp inst) {
// Lookup the referenced module. Nothing to do if its an extmodule.
auto module = dyn_cast<FModuleOp>(*instanceGraph->getReferencedModule(inst));
auto module = inst.getReferencedModule<FModuleOp>(*instanceGraph);
if (!module)
return;
LLVM_DEBUG(llvm::dbgs() << "Visiting instance " << inst.getName() << "\n");
Expand Down Expand Up @@ -1114,8 +1114,8 @@ LogicalResult InferResetsPass::updateReset(ResetNetwork net, ResetKind kind) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
moduleWorklist.insert(blockArg.getOwner()->getParentOp());
else if (auto instOp = value.getDefiningOp<InstanceOp>()) {
if (auto extmodule = dyn_cast<FExtModuleOp>(
*instanceGraph->getReferencedModule(instOp)))
if (auto extmodule =
instOp.getReferencedModule<FExtModuleOp>(*instanceGraph))
extmoduleWorklist.insert({extmodule, instOp});
} else if (auto uncast = value.getDefiningOp<UninferredResetCastOp>()) {
uncast.replaceAllUsesWith(uncast.getInput());
Expand Down Expand Up @@ -1742,8 +1742,7 @@ void InferResetsPass::implementAsyncReset(Operation *op, FModuleOp module,
// Lookup the reset domain of the instantiated module. If there is no
// reset domain associated with that module, or the module is explicitly
// marked as being in no domain, simply skip.
auto refModule =
dyn_cast<FModuleOp>(*instanceGraph->getReferencedModule(instOp));
auto refModule = instOp.getReferencedModule<FModuleOp>(*instanceGraph);
if (!refModule)
return;
auto domainIt = domains.find(refModule);
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/FIRRTL/Transforms/InferWidths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1635,7 +1635,7 @@ LogicalResult InferenceMapping::mapOperation(Operation *op) {

// Handle instances of other modules.
.Case<InstanceOp>([&](auto op) {
auto refdModule = op.getReferencedModule(symtbl);
auto refdModule = op.getReferencedOperation(symtbl);
auto module = dyn_cast<FModuleOp>(&*refdModule);
if (!module) {
auto diag = mlir::emitError(op.getLoc());
Expand Down
9 changes: 5 additions & 4 deletions lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -808,8 +808,8 @@ LogicalResult LowerAnnotationsPass::solveWiringProblems(ApplyState &state) {
while (!sources.empty() && !sinks.empty()) {
if (sources.top() != sinks.top())
break;
auto newLCA = sources.top();
lca = cast<FModuleOp>(instanceGraph.getReferencedModule(newLCA));
auto newLCA = cast<InstanceOp>(*sources.top());
lca = cast<FModuleOp>(newLCA.getReferencedModule(instanceGraph));
sources = sources.dropFront();
sinks = sinks.dropFront();
}
Expand Down Expand Up @@ -885,8 +885,9 @@ LogicalResult LowerAnnotationsPass::solveWiringProblems(ApplyState &state) {
auto addPorts = [&](igraph::InstancePath insts, Value val, Type tpe,
Direction dir) {
StringRef name, instName;
for (auto inst : llvm::reverse(insts)) {
auto mod = instanceGraph.getReferencedModule<FModuleOp>(inst);
for (auto instNode : llvm::reverse(insts)) {
auto inst = cast<InstanceOp>(*instNode);
auto mod = inst.getReferencedModule<FModuleOp>(instanceGraph);
if (name.empty()) {
if (problem.newNameHint.empty())
name = state.getNamespace(mod).newName(
Expand Down
Loading

0 comments on commit 0e31fd7

Please sign in to comment.