From 9e3f863bfcaeb88fcb639ae85c478b666f94ac0c Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Wed, 2 Oct 2024 12:37:05 -0600 Subject: [PATCH] [FIRRTL] Re-implement old EmitOMIR ports logic in LowerClasses. (#7651) This transformation used to exist in EmitOMIR to detect the presence of a specific field, and use it to add another specific field on the fly. While we could do this through other means downstream, for full compatibility we are re-implementing this logic. Because we are already doing many expensive passes through the object IR, this is implemented throughout the existing methods in LowerClasses, rather than adding a new pass or a new traversal to an existing pass. One consequence of this choice is we do need to perform some global mutation to collect objects that need to have this field added in a multi-threaded context. To support this, a mutex is added to each instance of the pass, and the collected objects are then processed serially. Otherwise, this is fairly straightforward. When we are declaring classes, we check if we need to add the extra field. When we are adding class bodies, we check if we need to add the extra block argument. When we are converting object instances, we check if we need to supply the extra argument. It is in this final case that we actually make a list of objects. Now that we are off the JSON based OMIR, this is the proper way to add a new field with the same type. --- .../Dialect/FIRRTL/FIRRTLAnnotationHelper.h | 2 + .../FIRRTL/Transforms/LowerClasses.cpp | 252 ++++++++++++++++-- test/Dialect/FIRRTL/lower-classes.mlir | 50 ++++ 3 files changed, 284 insertions(+), 20 deletions(-) diff --git a/include/circt/Dialect/FIRRTL/FIRRTLAnnotationHelper.h b/include/circt/Dialect/FIRRTL/FIRRTLAnnotationHelper.h index aee473f85ee0..91ddc5b8bea2 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLAnnotationHelper.h +++ b/include/circt/Dialect/FIRRTL/FIRRTLAnnotationHelper.h @@ -319,6 +319,8 @@ struct HierPathCache { return FlatSymbolRefAttr::get(getSymFor(attr)); } + const SymbolTable &getSymbolTable() const { return symbolTable; } + private: OpBuilder builder; DenseMap cache; diff --git a/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp b/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp index 398949c55605..99ebd1446063 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp @@ -198,6 +198,13 @@ struct LoweringState { DenseMap classLoweringStateTable; }; +/// Helper struct to capture state about an object that needs RtlPorts added. +struct RtlPortsInfo { + firrtl::PathOp containingModuleRef; + Value basePath; + om::ObjectOp object; +}; + struct LowerClassesPass : public circt::firrtl::impl::LowerClassesBase { void runOnOperation() override; @@ -213,7 +220,8 @@ struct LowerClassesPass // Create an OM Class op from a FIRRTL Class op. om::ClassLike createClass(FModuleLike moduleLike, - const PathInfoTable &pathInfoTable); + const PathInfoTable &pathInfoTable, + std::mutex &intraPassMutex); // Lower the FIRRTL Class to OM Class. void lowerClassLike(FModuleLike moduleLike, om::ClassLike classLike, @@ -225,7 +233,13 @@ struct LowerClassesPass // Update Object instantiations in a FIRRTL Module or OM Class. LogicalResult updateInstances(Operation *op, InstanceGraph &instanceGraph, const LoweringState &state, - const PathInfoTable &pathInfoTable); + const PathInfoTable &pathInfoTable, + std::mutex &intraPassMutex); + + /// Create and add all 'ports' lists of RtlPort objects for each object. + void createAllRtlPorts(const PathInfoTable &pathInfoTable, + hw::InnerSymbolNamespaceCollection &namespaces, + HierPathCache &hierPathCache); // Convert to OM ops and types in Classes or Modules. LogicalResult dialectConversion( @@ -234,6 +248,9 @@ struct LowerClassesPass // State to memoize repeated calls to shouldCreateClass. DenseMap shouldCreateClassMemo; + + // State used while creating the optional 'ports' list of RtlPort objects. + SmallVector rtlPortsToCreate; }; struct PathTracker { @@ -294,6 +311,116 @@ struct PathTracker { SetVector altBasePathRoots; }; +/// Constants and helpers for creating the RtlPorts on the fly. + +static constexpr StringRef kContainingModuleName = "containingModule"; +static constexpr StringRef kPortsName = "ports"; +static constexpr StringRef kRtlPortClassName = "RtlPort"; + +static Type getRtlPortsType(MLIRContext *context) { + return om::ListType::get(om::ClassType::get( + context, FlatSymbolRefAttr::get(context, kRtlPortClassName))); +} + +/// Create and add the 'ports' list of RtlPort objects for an object. +static void createRtlPorts(const RtlPortsInfo &rtlPortToCreate, + const PathInfoTable &pathInfoTable, + hw::InnerSymbolNamespaceCollection &namespaces, + HierPathCache &hierPathCache, OpBuilder &builder) { + firrtl::PathOp containingModuleRef = rtlPortToCreate.containingModuleRef; + Value basePath = rtlPortToCreate.basePath; + om::ObjectOp object = rtlPortToCreate.object; + + // Set the builder to just before the object. + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(object); + + // Look up the module from the containingModuleRef. + + FlatSymbolRefAttr containingModulePathRef = + pathInfoTable.table.at(containingModuleRef.getTarget()).symRef; + + const SymbolTable &symbolTable = hierPathCache.getSymbolTable(); + + hw::HierPathOp containingModulePath = + symbolTable.lookup(containingModulePathRef.getAttr()); + + assert(containingModulePath.isModule() && + "expected containing module path to target a module"); + + StringAttr moduleName = containingModulePath.leafMod(); + + FModuleLike mod = symbolTable.lookup(moduleName); + MLIRContext *ctx = mod.getContext(); + Location loc = mod.getLoc(); + + // Create the per-port information. + + auto portClassName = StringAttr::get(ctx, kRtlPortClassName); + auto portClassType = + om::ClassType::get(ctx, FlatSymbolRefAttr::get(portClassName)); + + SmallVector ports; + for (unsigned i = 0, e = mod.getNumPorts(); i < e; ++i) { + // Only process ports that are not zero-width. + auto portType = type_dyn_cast(mod.getPortType(i)); + if (!portType || portType.getBitWidthOrSentinel() == 0) + continue; + + // Get a path to the port. This may modify port attributes or the global + // namespace of hierpaths, so use the mutex around those operations. + + auto portTarget = PortAnnoTarget(mod, i); + + auto portSym = + getInnerRefTo({portTarget.getPortNo(), portTarget.getOp(), 0}, + [&](FModuleLike m) -> hw::InnerSymbolNamespace & { + return namespaces[m]; + }); + + FlatSymbolRefAttr portPathRef = + hierPathCache.getRefFor(ArrayAttr::get(ctx, {portSym})); + + auto portPath = builder.create( + loc, om::PathType::get(ctx), + om::TargetKindAttr::get(ctx, om::TargetKind::DontTouch), basePath, + portPathRef); + + // Get a direction attribute. + + StringRef portDirectionName = + mod.getPortDirection(i) == Direction::Out ? "Output" : "Input"; + + auto portDirection = builder.create( + loc, om::StringType::get(ctx), + StringAttr::get(portDirectionName, om::StringType::get(ctx))); + + // Get a width attribute. + + auto portWidth = builder.create( + loc, om::OMIntegerType::get(ctx), + om::IntegerAttr::get( + ctx, mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64), + portType.getBitWidthOrSentinel()))); + + // Create an RtlPort object for this port, and add it to the list. + + auto portObj = builder.create( + loc, portClassType, portClassName, + ArrayRef{portPath, portDirection, portWidth}); + + ports.push_back(portObj); + } + + // Create a list of RtlPort objects to be included with the containingModule. + + auto portsList = builder.create( + UnknownLoc::get(builder.getContext()), + getRtlPortsType(builder.getContext()), ports); + + object.getActualParamsMutable().append({portsList}); +} + } // namespace LogicalResult @@ -701,6 +828,7 @@ LogicalResult LowerClassesPass::processPaths( /// Lower FIRRTL Class and Object ops to OM Class and Object ops void LowerClassesPass::runOnOperation() { MLIRContext *ctx = &getContext(); + auto intraPassMutex = std::mutex(); // Get the CircuitOp. CircuitOp circuit = getOperation(); @@ -742,7 +870,7 @@ void LowerClassesPass::runOnOperation() { continue; if (shouldCreateClass(moduleLike.getModuleNameAttr())) { - auto omClass = createClass(moduleLike, pathInfoTable); + auto omClass = createClass(moduleLike, pathInfoTable, intraPassMutex); auto &classLoweringState = loweringState.classLoweringStateTable[omClass]; classLoweringState.moduleLike = moduleLike; @@ -804,10 +932,14 @@ void LowerClassesPass::runOnOperation() { if (failed( mlir::failableParallelForEach(ctx, objectContainers, [&](auto *op) { return updateInstances(op, instanceGraph, loweringState, - pathInfoTable); + pathInfoTable, intraPassMutex); }))) return signalPassFailure(); + // If needed, create and add 'ports' lists of RtlPort objects. + if (!rtlPortsToCreate.empty()) + createAllRtlPorts(pathInfoTable, namespaces, cache); + // Convert to OM ops and types in Classes or Modules in parallel. if (failed( mlir::failableParallelForEach(ctx, objectContainers, [&](auto *op) { @@ -817,6 +949,9 @@ void LowerClassesPass::runOnOperation() { // We keep the instance graph up to date, so mark that analysis preserved. markAnalysesPreserved(); + + // Reset pass state. + rtlPortsToCreate.clear(); } std::unique_ptr circt::firrtl::createLowerClassesPass() { @@ -830,9 +965,9 @@ bool LowerClassesPass::shouldCreateClass(StringAttr modName) { } // Create an OM Class op from a FIRRTL Class op or Module op with properties. -om::ClassLike -LowerClassesPass::createClass(FModuleLike moduleLike, - const PathInfoTable &pathInfoTable) { +om::ClassLike LowerClassesPass::createClass(FModuleLike moduleLike, + const PathInfoTable &pathInfoTable, + std::mutex &intraPassMutex) { // Collect the parameter names from input properties. SmallVector formalParamNames; // Every class gets a base path as its first parameter. @@ -845,12 +980,24 @@ LowerClassesPass::createClass(FModuleLike moduleLike, formalParamNames.push_back(StringAttr::get( moduleLike->getContext(), "alt_basepath_" + llvm::Twine(i))); - for (auto [index, port] : llvm::enumerate(moduleLike.getPorts())) - if (port.isInput() && isa(port.type)) + // Collect the input parameters. + bool hasContainingModule = false; + for (auto [index, port] : llvm::enumerate(moduleLike.getPorts())) { + if (port.isInput() && isa(port.type)) { formalParamNames.push_back(port.name); + // Check if we have a 'containingModule' field. + if (port.name.strref().starts_with(kContainingModuleName)) + hasContainingModule = true; + } + } + OpBuilder builder = OpBuilder::atBlockEnd(getOperation().getBodyBlock()); + // If there is a 'containingModule', add a parameter for 'ports'. + if (hasContainingModule) + formalParamNames.push_back(kPortsName); + // Take the name from the FIRRTL Class or Module to create the OM Class name. StringRef className = moduleLike.getName(); @@ -898,15 +1045,21 @@ void LowerClassesPass::lowerClass(om::ClassOp classOp, FModuleLike moduleLike, // Collect information about property ports. SmallVector inputProperties; BitVector portsToErase(moduleLike.getNumPorts()); + bool hasContainingModule = false; for (auto [index, port] : llvm::enumerate(moduleLike.getPorts())) { // For Module ports that aren't property types, move along. if (!isa(port.type)) continue; // Remember input properties to create the OM Class formal parameters. - if (port.isInput()) + if (port.isInput()) { inputProperties.push_back({index, port.name, port.type, port.loc}); + // Check if we have a 'containingModule' field. + if (port.name.strref().starts_with(kContainingModuleName)) + hasContainingModule = true; + } + // In case this is a Module, remember to erase this port. portsToErase.set(index); } @@ -976,6 +1129,14 @@ void LowerClassesPass::lowerClass(om::ClassOp classOp, FModuleLike moduleLike, op.erase(); } + // If there is a 'containingModule', add an argument for 'ports', and a field. + if (hasContainingModule) { + BlockArgument argumentValue = classBody->addArgument( + getRtlPortsType(&getContext()), UnknownLoc::get(&getContext())); + builder.create(argumentValue.getLoc(), kPortsName, + argumentValue); + } + // If the module-like is a Class, it will be completely erased later. // Otherwise, erase just the property ports and ops. if (!isa(moduleLike.getOperation())) { @@ -1029,10 +1190,10 @@ void LowerClassesPass::lowerClassExtern(ClassExternOp classExternOp, // Helper to update an Object instantiation. FIRRTL Object instances are // converted to OM Object instances. -static LogicalResult -updateObjectInClass(firrtl::ObjectOp firrtlObject, - const PathInfoTable &pathInfoTable, - SmallVectorImpl &opsToErase) { +static LogicalResult updateObjectInClass( + firrtl::ObjectOp firrtlObject, const PathInfoTable &pathInfoTable, + SmallVectorImpl &rtlPortsToCreate, std::mutex &intraPassMutex, + SmallVectorImpl &opsToErase) { // The 0'th argument is the base path. auto basePath = firrtlObject->getBlock()->getArgument(0); // build a table mapping the indices of input ports to their position in the @@ -1067,6 +1228,7 @@ updateObjectInClass(firrtl::ObjectOp firrtlObject, for (auto [i, altBasePath] : llvm::enumerate(altBasePaths)) args[1 + i] = altBasePath; // + 1 to skip default base path + firrtl::PathOp containingModuleRef; for (auto *user : llvm::make_early_inc_range(firrtlObject->getUsers())) { if (auto subfield = dyn_cast(user)) { auto index = subfield.getIndex(); @@ -1088,6 +1250,16 @@ updateObjectInClass(firrtl::ObjectOp firrtlObject, if (dst == subfield.getResult()) { args[argIndexTable[index]] = src; opsToErase.push_back(propassign); + + // Check if we have a 'containingModule' field. + if (firrtlClassType.getElement(index).name.strref().starts_with( + kContainingModuleName)) { + assert(!containingModuleRef && + "expected exactly one containingModule"); + assert(isa_and_nonnull(src.getDefiningOp()) && + "expected containingModule to be a PathOp"); + containingModuleRef = src.getDefiningOp(); + } } } } @@ -1114,9 +1286,16 @@ updateObjectInClass(firrtl::ObjectOp firrtlObject, // Create the new Object op. OpBuilder builder(firrtlObject); + auto object = builder.create( firrtlObject.getLoc(), classType, firrtlObject.getClassNameAttr(), args); + // If there is a 'containingModule', track that we need to add 'ports'. + if (containingModuleRef) { + std::lock_guard guard(intraPassMutex); + rtlPortsToCreate.push_back({containingModuleRef, basePath, object}); + } + // Replace uses of the FIRRTL Object with the OM Object. The later dialect // conversion will take care of converting the types. firrtlObject.replaceAllUsesWith(object.getResult()); @@ -1275,13 +1454,15 @@ updateInstancesInModule(FModuleOp moduleOp, InstanceGraph &instanceGraph, static LogicalResult updateObjectsAndInstancesInClass( om::ClassOp classOp, InstanceGraph &instanceGraph, const LoweringState &state, const PathInfoTable &pathInfoTable, + SmallVectorImpl &rtlPortsToCreate, std::mutex &intraPassMutex, SmallVectorImpl &opsToErase) { OpBuilder builder(classOp); auto &classState = state.classLoweringStateTable.at(classOp); auto it = classState.paths.begin(); for (auto &op : classOp->getRegion(0).getOps()) { if (auto objectOp = dyn_cast(op)) { - if (failed(updateObjectInClass(objectOp, pathInfoTable, opsToErase))) + if (failed(updateObjectInClass(objectOp, pathInfoTable, rtlPortsToCreate, + intraPassMutex, opsToErase))) return failure(); } else if (auto instanceOp = dyn_cast(op)) { if (failed(updateInstanceInClass(instanceOp, *it++, instanceGraph, @@ -1293,10 +1474,9 @@ static LogicalResult updateObjectsAndInstancesInClass( } // Update Object or Module instantiations in a FIRRTL Module or OM Class. -LogicalResult -LowerClassesPass::updateInstances(Operation *op, InstanceGraph &instanceGraph, - const LoweringState &state, - const PathInfoTable &pathInfoTable) { +LogicalResult LowerClassesPass::updateInstances( + Operation *op, InstanceGraph &instanceGraph, const LoweringState &state, + const PathInfoTable &pathInfoTable, std::mutex &intraPassMutex) { // Track ops to erase at the end. We can't do this eagerly, since we want to // loop over each op in the container's body, and we may end up removing some @@ -1314,7 +1494,8 @@ LowerClassesPass::updateInstances(Operation *op, InstanceGraph &instanceGraph, // Convert FIRRTL Module instance within a Class to OM // Object instance. return updateObjectsAndInstancesInClass( - classOp, instanceGraph, state, pathInfoTable, opsToErase); + classOp, instanceGraph, state, pathInfoTable, rtlPortsToCreate, + intraPassMutex, opsToErase); }) .Default([](auto *op) { return success(); }); if (failed(result)) @@ -1326,6 +1507,34 @@ LowerClassesPass::updateInstances(Operation *op, InstanceGraph &instanceGraph, return success(); } +// Create and add all 'ports' lists of RtlPort objects for each object. +void LowerClassesPass::createAllRtlPorts( + const PathInfoTable &pathInfoTable, + hw::InnerSymbolNamespaceCollection &namespaces, + HierPathCache &hierPathCache) { + MLIRContext *ctx = &getContext(); + + // Get a builder initialized to the end of the top-level module. + OpBuilder builder = OpBuilder::atBlockEnd(getOperation().getBodyBlock()); + + // Declare an RtlPort class on the fly. + om::ClassOp::buildSimpleClassOp( + builder, UnknownLoc::get(ctx), kRtlPortClassName, + {"ref", "direction", "width"}, {"ref", "direction", "width"}, + {om::PathType::get(ctx), om::StringType::get(ctx), + om::OMIntegerType::get(ctx)}); + + // Sort the collected rtlPortsToCreate and process each. + llvm::stable_sort(rtlPortsToCreate, [](auto lhs, auto rhs) { + return lhs.object.getClassName() < rhs.object.getClassName(); + }); + + // Create each 'ports' list. + for (auto rtlPortToCreate : rtlPortsToCreate) + createRtlPorts(rtlPortToCreate, pathInfoTable, namespaces, hierPathCache, + builder); +} + // Pattern rewriters for dialect conversion. struct FIntegerConstantOpConversion @@ -1831,6 +2040,9 @@ static void populateTypeConverter(TypeConverter &converter) { // Convert FIRRTL List type to OM List type. auto convertListType = [&converter](auto type) -> std::optional { + // If the element type is already in the OM dialect, there's nothing to do. + if (isa(type.getElementType().getDialect())) + return type; auto elementType = converter.convertType(type.getElementType()); if (!elementType) return {}; diff --git a/test/Dialect/FIRRTL/lower-classes.mlir b/test/Dialect/FIRRTL/lower-classes.mlir index 02fcd5613822..82f518075c6b 100644 --- a/test/Dialect/FIRRTL/lower-classes.mlir +++ b/test/Dialect/FIRRTL/lower-classes.mlir @@ -494,3 +494,53 @@ firrtl.circuit "LocalPath" { firrtl.instance child1 @Child() } } + +// CHECK-LABEL: firrtl.circuit "RTLPorts" +firrtl.circuit "RTLPorts" { + // CHECK: hw.hierpath private [[INPUT_NLA:@.+]] [@SomeModule::[[INPUT_SYM:@.+]]] + // CHECK: hw.hierpath private [[OUTPUT_NLA:@.+]] [@SomeModule::[[OUTPUT_SYM:@.+]]] + + // CHECK: firrtl.module @SomeModule + // CHECK-SAME: in %input: !firrtl.uint<1> sym [[INPUT_SYM]] + // CHECK-SAME: out %output: !firrtl.uint<8> sym [[OUTPUT_SYM]] + firrtl.module @SomeModule(in %input: !firrtl.uint<1>, out %output: !firrtl.uint<8>) attributes {annotations = [{class = "circt.tracker", id = distinct[0]<>}]} { + } + + firrtl.module @RTLPorts() { + firrtl.instance inst @SomeModule(in input: !firrtl.uint<1>, out output: !firrtl.uint<8>) + + // CHECK: [[MODULE_PATH:%.+]] = om.path_create instance %basepath {{.+}} + %path = firrtl.path instance distinct[0]<> + + // CHECK: [[INPUT_REF:%.+]] = om.path_create dont_touch %basepath [[INPUT_NLA]] + // CHECK: [[INPUT_DIR:%.+]] = om.constant "Input" + // CHECK: [[INPUT_WIDTH:%.+]] = om.constant #om.integer<1 : i64> + // CHECK: [[INPUT_OBJ:%.+]] = om.object @RtlPort([[INPUT_REF]], [[INPUT_DIR]], [[INPUT_WIDTH]]) + + // CHECK: [[OUTPUT_REF:%.+]] = om.path_create dont_touch %basepath [[OUTPUT_NLA]] + // CHECK: [[OUTPUT_DIR:%.+]] = om.constant "Output" + // CHECK: [[OUTPUT_WIDTH:%.+]] = om.constant #om.integer<8 : i64> + // CHECK: [[OUTPUT_OBJ:%.+]] = om.object @RtlPort([[OUTPUT_REF]], [[OUTPUT_DIR]], [[OUTPUT_WIDTH]]) + + // CHECK: [[PORTS_LIST:%.+]] = om.list_create [[INPUT_OBJ]], [[OUTPUT_OBJ]] + + // CHECK: om.object @NeedsRTLPorts(%basepath, [[MODULE_PATH]], [[PORTS_LIST]]) + %object = firrtl.object @NeedsRTLPorts(in containingModule_in: !firrtl.path, out containingModule: !firrtl.path) + %field = firrtl.object.subfield %object[containingModule_in] : !firrtl.class<@NeedsRTLPorts(in containingModule_in: !firrtl.path, out containingModule: !firrtl.path)> + firrtl.propassign %field, %path : !firrtl.path + + // Add a second instance to ensure the RtlPorts class is only declared once. + %object2 = firrtl.object @NeedsRTLPorts(in containingModule_in: !firrtl.path, out containingModule: !firrtl.path) + %field2 = firrtl.object.subfield %object2[containingModule_in] : !firrtl.class<@NeedsRTLPorts(in containingModule_in: !firrtl.path, out containingModule: !firrtl.path)> + firrtl.propassign %field2, %path : !firrtl.path + } + + // CHECK: om.class @NeedsRTLPorts + // CHECK-NEXT: om.class.field @containingModule + // CHECK-NEXT: om.class.field @ports + firrtl.class @NeedsRTLPorts(in %containingModule_in: !firrtl.path, out %containingModule: !firrtl.path) { + firrtl.propassign %containingModule, %containingModule_in : !firrtl.path + } + + // CHECK-COUNT-1: om.class @RtlPort +}