diff --git a/include/circt/Dialect/FIRRTL/FIRRTLAnnotationHelper.h b/include/circt/Dialect/FIRRTL/FIRRTLAnnotationHelper.h index aee473f85ee0..91ddc5b8bea2 100644 --- a/include/circt/Dialect/FIRRTL/FIRRTLAnnotationHelper.h +++ b/include/circt/Dialect/FIRRTL/FIRRTLAnnotationHelper.h @@ -319,6 +319,8 @@ struct HierPathCache { return FlatSymbolRefAttr::get(getSymFor(attr)); } + const SymbolTable &getSymbolTable() const { return symbolTable; } + private: OpBuilder builder; DenseMap cache; diff --git a/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp b/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp index 398949c55605..99ebd1446063 100644 --- a/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp +++ b/lib/Dialect/FIRRTL/Transforms/LowerClasses.cpp @@ -198,6 +198,13 @@ struct LoweringState { DenseMap classLoweringStateTable; }; +/// Helper struct to capture state about an object that needs RtlPorts added. +struct RtlPortsInfo { + firrtl::PathOp containingModuleRef; + Value basePath; + om::ObjectOp object; +}; + struct LowerClassesPass : public circt::firrtl::impl::LowerClassesBase { void runOnOperation() override; @@ -213,7 +220,8 @@ struct LowerClassesPass // Create an OM Class op from a FIRRTL Class op. om::ClassLike createClass(FModuleLike moduleLike, - const PathInfoTable &pathInfoTable); + const PathInfoTable &pathInfoTable, + std::mutex &intraPassMutex); // Lower the FIRRTL Class to OM Class. void lowerClassLike(FModuleLike moduleLike, om::ClassLike classLike, @@ -225,7 +233,13 @@ struct LowerClassesPass // Update Object instantiations in a FIRRTL Module or OM Class. LogicalResult updateInstances(Operation *op, InstanceGraph &instanceGraph, const LoweringState &state, - const PathInfoTable &pathInfoTable); + const PathInfoTable &pathInfoTable, + std::mutex &intraPassMutex); + + /// Create and add all 'ports' lists of RtlPort objects for each object. + void createAllRtlPorts(const PathInfoTable &pathInfoTable, + hw::InnerSymbolNamespaceCollection &namespaces, + HierPathCache &hierPathCache); // Convert to OM ops and types in Classes or Modules. LogicalResult dialectConversion( @@ -234,6 +248,9 @@ struct LowerClassesPass // State to memoize repeated calls to shouldCreateClass. DenseMap shouldCreateClassMemo; + + // State used while creating the optional 'ports' list of RtlPort objects. + SmallVector rtlPortsToCreate; }; struct PathTracker { @@ -294,6 +311,116 @@ struct PathTracker { SetVector altBasePathRoots; }; +/// Constants and helpers for creating the RtlPorts on the fly. + +static constexpr StringRef kContainingModuleName = "containingModule"; +static constexpr StringRef kPortsName = "ports"; +static constexpr StringRef kRtlPortClassName = "RtlPort"; + +static Type getRtlPortsType(MLIRContext *context) { + return om::ListType::get(om::ClassType::get( + context, FlatSymbolRefAttr::get(context, kRtlPortClassName))); +} + +/// Create and add the 'ports' list of RtlPort objects for an object. +static void createRtlPorts(const RtlPortsInfo &rtlPortToCreate, + const PathInfoTable &pathInfoTable, + hw::InnerSymbolNamespaceCollection &namespaces, + HierPathCache &hierPathCache, OpBuilder &builder) { + firrtl::PathOp containingModuleRef = rtlPortToCreate.containingModuleRef; + Value basePath = rtlPortToCreate.basePath; + om::ObjectOp object = rtlPortToCreate.object; + + // Set the builder to just before the object. + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(object); + + // Look up the module from the containingModuleRef. + + FlatSymbolRefAttr containingModulePathRef = + pathInfoTable.table.at(containingModuleRef.getTarget()).symRef; + + const SymbolTable &symbolTable = hierPathCache.getSymbolTable(); + + hw::HierPathOp containingModulePath = + symbolTable.lookup(containingModulePathRef.getAttr()); + + assert(containingModulePath.isModule() && + "expected containing module path to target a module"); + + StringAttr moduleName = containingModulePath.leafMod(); + + FModuleLike mod = symbolTable.lookup(moduleName); + MLIRContext *ctx = mod.getContext(); + Location loc = mod.getLoc(); + + // Create the per-port information. + + auto portClassName = StringAttr::get(ctx, kRtlPortClassName); + auto portClassType = + om::ClassType::get(ctx, FlatSymbolRefAttr::get(portClassName)); + + SmallVector ports; + for (unsigned i = 0, e = mod.getNumPorts(); i < e; ++i) { + // Only process ports that are not zero-width. + auto portType = type_dyn_cast(mod.getPortType(i)); + if (!portType || portType.getBitWidthOrSentinel() == 0) + continue; + + // Get a path to the port. This may modify port attributes or the global + // namespace of hierpaths, so use the mutex around those operations. + + auto portTarget = PortAnnoTarget(mod, i); + + auto portSym = + getInnerRefTo({portTarget.getPortNo(), portTarget.getOp(), 0}, + [&](FModuleLike m) -> hw::InnerSymbolNamespace & { + return namespaces[m]; + }); + + FlatSymbolRefAttr portPathRef = + hierPathCache.getRefFor(ArrayAttr::get(ctx, {portSym})); + + auto portPath = builder.create( + loc, om::PathType::get(ctx), + om::TargetKindAttr::get(ctx, om::TargetKind::DontTouch), basePath, + portPathRef); + + // Get a direction attribute. + + StringRef portDirectionName = + mod.getPortDirection(i) == Direction::Out ? "Output" : "Input"; + + auto portDirection = builder.create( + loc, om::StringType::get(ctx), + StringAttr::get(portDirectionName, om::StringType::get(ctx))); + + // Get a width attribute. + + auto portWidth = builder.create( + loc, om::OMIntegerType::get(ctx), + om::IntegerAttr::get( + ctx, mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64), + portType.getBitWidthOrSentinel()))); + + // Create an RtlPort object for this port, and add it to the list. + + auto portObj = builder.create( + loc, portClassType, portClassName, + ArrayRef{portPath, portDirection, portWidth}); + + ports.push_back(portObj); + } + + // Create a list of RtlPort objects to be included with the containingModule. + + auto portsList = builder.create( + UnknownLoc::get(builder.getContext()), + getRtlPortsType(builder.getContext()), ports); + + object.getActualParamsMutable().append({portsList}); +} + } // namespace LogicalResult @@ -701,6 +828,7 @@ LogicalResult LowerClassesPass::processPaths( /// Lower FIRRTL Class and Object ops to OM Class and Object ops void LowerClassesPass::runOnOperation() { MLIRContext *ctx = &getContext(); + auto intraPassMutex = std::mutex(); // Get the CircuitOp. CircuitOp circuit = getOperation(); @@ -742,7 +870,7 @@ void LowerClassesPass::runOnOperation() { continue; if (shouldCreateClass(moduleLike.getModuleNameAttr())) { - auto omClass = createClass(moduleLike, pathInfoTable); + auto omClass = createClass(moduleLike, pathInfoTable, intraPassMutex); auto &classLoweringState = loweringState.classLoweringStateTable[omClass]; classLoweringState.moduleLike = moduleLike; @@ -804,10 +932,14 @@ void LowerClassesPass::runOnOperation() { if (failed( mlir::failableParallelForEach(ctx, objectContainers, [&](auto *op) { return updateInstances(op, instanceGraph, loweringState, - pathInfoTable); + pathInfoTable, intraPassMutex); }))) return signalPassFailure(); + // If needed, create and add 'ports' lists of RtlPort objects. + if (!rtlPortsToCreate.empty()) + createAllRtlPorts(pathInfoTable, namespaces, cache); + // Convert to OM ops and types in Classes or Modules in parallel. if (failed( mlir::failableParallelForEach(ctx, objectContainers, [&](auto *op) { @@ -817,6 +949,9 @@ void LowerClassesPass::runOnOperation() { // We keep the instance graph up to date, so mark that analysis preserved. markAnalysesPreserved(); + + // Reset pass state. + rtlPortsToCreate.clear(); } std::unique_ptr circt::firrtl::createLowerClassesPass() { @@ -830,9 +965,9 @@ bool LowerClassesPass::shouldCreateClass(StringAttr modName) { } // Create an OM Class op from a FIRRTL Class op or Module op with properties. -om::ClassLike -LowerClassesPass::createClass(FModuleLike moduleLike, - const PathInfoTable &pathInfoTable) { +om::ClassLike LowerClassesPass::createClass(FModuleLike moduleLike, + const PathInfoTable &pathInfoTable, + std::mutex &intraPassMutex) { // Collect the parameter names from input properties. SmallVector formalParamNames; // Every class gets a base path as its first parameter. @@ -845,12 +980,24 @@ LowerClassesPass::createClass(FModuleLike moduleLike, formalParamNames.push_back(StringAttr::get( moduleLike->getContext(), "alt_basepath_" + llvm::Twine(i))); - for (auto [index, port] : llvm::enumerate(moduleLike.getPorts())) - if (port.isInput() && isa(port.type)) + // Collect the input parameters. + bool hasContainingModule = false; + for (auto [index, port] : llvm::enumerate(moduleLike.getPorts())) { + if (port.isInput() && isa(port.type)) { formalParamNames.push_back(port.name); + // Check if we have a 'containingModule' field. + if (port.name.strref().starts_with(kContainingModuleName)) + hasContainingModule = true; + } + } + OpBuilder builder = OpBuilder::atBlockEnd(getOperation().getBodyBlock()); + // If there is a 'containingModule', add a parameter for 'ports'. + if (hasContainingModule) + formalParamNames.push_back(kPortsName); + // Take the name from the FIRRTL Class or Module to create the OM Class name. StringRef className = moduleLike.getName(); @@ -898,15 +1045,21 @@ void LowerClassesPass::lowerClass(om::ClassOp classOp, FModuleLike moduleLike, // Collect information about property ports. SmallVector inputProperties; BitVector portsToErase(moduleLike.getNumPorts()); + bool hasContainingModule = false; for (auto [index, port] : llvm::enumerate(moduleLike.getPorts())) { // For Module ports that aren't property types, move along. if (!isa(port.type)) continue; // Remember input properties to create the OM Class formal parameters. - if (port.isInput()) + if (port.isInput()) { inputProperties.push_back({index, port.name, port.type, port.loc}); + // Check if we have a 'containingModule' field. + if (port.name.strref().starts_with(kContainingModuleName)) + hasContainingModule = true; + } + // In case this is a Module, remember to erase this port. portsToErase.set(index); } @@ -976,6 +1129,14 @@ void LowerClassesPass::lowerClass(om::ClassOp classOp, FModuleLike moduleLike, op.erase(); } + // If there is a 'containingModule', add an argument for 'ports', and a field. + if (hasContainingModule) { + BlockArgument argumentValue = classBody->addArgument( + getRtlPortsType(&getContext()), UnknownLoc::get(&getContext())); + builder.create(argumentValue.getLoc(), kPortsName, + argumentValue); + } + // If the module-like is a Class, it will be completely erased later. // Otherwise, erase just the property ports and ops. if (!isa(moduleLike.getOperation())) { @@ -1029,10 +1190,10 @@ void LowerClassesPass::lowerClassExtern(ClassExternOp classExternOp, // Helper to update an Object instantiation. FIRRTL Object instances are // converted to OM Object instances. -static LogicalResult -updateObjectInClass(firrtl::ObjectOp firrtlObject, - const PathInfoTable &pathInfoTable, - SmallVectorImpl &opsToErase) { +static LogicalResult updateObjectInClass( + firrtl::ObjectOp firrtlObject, const PathInfoTable &pathInfoTable, + SmallVectorImpl &rtlPortsToCreate, std::mutex &intraPassMutex, + SmallVectorImpl &opsToErase) { // The 0'th argument is the base path. auto basePath = firrtlObject->getBlock()->getArgument(0); // build a table mapping the indices of input ports to their position in the @@ -1067,6 +1228,7 @@ updateObjectInClass(firrtl::ObjectOp firrtlObject, for (auto [i, altBasePath] : llvm::enumerate(altBasePaths)) args[1 + i] = altBasePath; // + 1 to skip default base path + firrtl::PathOp containingModuleRef; for (auto *user : llvm::make_early_inc_range(firrtlObject->getUsers())) { if (auto subfield = dyn_cast(user)) { auto index = subfield.getIndex(); @@ -1088,6 +1250,16 @@ updateObjectInClass(firrtl::ObjectOp firrtlObject, if (dst == subfield.getResult()) { args[argIndexTable[index]] = src; opsToErase.push_back(propassign); + + // Check if we have a 'containingModule' field. + if (firrtlClassType.getElement(index).name.strref().starts_with( + kContainingModuleName)) { + assert(!containingModuleRef && + "expected exactly one containingModule"); + assert(isa_and_nonnull(src.getDefiningOp()) && + "expected containingModule to be a PathOp"); + containingModuleRef = src.getDefiningOp(); + } } } } @@ -1114,9 +1286,16 @@ updateObjectInClass(firrtl::ObjectOp firrtlObject, // Create the new Object op. OpBuilder builder(firrtlObject); + auto object = builder.create( firrtlObject.getLoc(), classType, firrtlObject.getClassNameAttr(), args); + // If there is a 'containingModule', track that we need to add 'ports'. + if (containingModuleRef) { + std::lock_guard guard(intraPassMutex); + rtlPortsToCreate.push_back({containingModuleRef, basePath, object}); + } + // Replace uses of the FIRRTL Object with the OM Object. The later dialect // conversion will take care of converting the types. firrtlObject.replaceAllUsesWith(object.getResult()); @@ -1275,13 +1454,15 @@ updateInstancesInModule(FModuleOp moduleOp, InstanceGraph &instanceGraph, static LogicalResult updateObjectsAndInstancesInClass( om::ClassOp classOp, InstanceGraph &instanceGraph, const LoweringState &state, const PathInfoTable &pathInfoTable, + SmallVectorImpl &rtlPortsToCreate, std::mutex &intraPassMutex, SmallVectorImpl &opsToErase) { OpBuilder builder(classOp); auto &classState = state.classLoweringStateTable.at(classOp); auto it = classState.paths.begin(); for (auto &op : classOp->getRegion(0).getOps()) { if (auto objectOp = dyn_cast(op)) { - if (failed(updateObjectInClass(objectOp, pathInfoTable, opsToErase))) + if (failed(updateObjectInClass(objectOp, pathInfoTable, rtlPortsToCreate, + intraPassMutex, opsToErase))) return failure(); } else if (auto instanceOp = dyn_cast(op)) { if (failed(updateInstanceInClass(instanceOp, *it++, instanceGraph, @@ -1293,10 +1474,9 @@ static LogicalResult updateObjectsAndInstancesInClass( } // Update Object or Module instantiations in a FIRRTL Module or OM Class. -LogicalResult -LowerClassesPass::updateInstances(Operation *op, InstanceGraph &instanceGraph, - const LoweringState &state, - const PathInfoTable &pathInfoTable) { +LogicalResult LowerClassesPass::updateInstances( + Operation *op, InstanceGraph &instanceGraph, const LoweringState &state, + const PathInfoTable &pathInfoTable, std::mutex &intraPassMutex) { // Track ops to erase at the end. We can't do this eagerly, since we want to // loop over each op in the container's body, and we may end up removing some @@ -1314,7 +1494,8 @@ LowerClassesPass::updateInstances(Operation *op, InstanceGraph &instanceGraph, // Convert FIRRTL Module instance within a Class to OM // Object instance. return updateObjectsAndInstancesInClass( - classOp, instanceGraph, state, pathInfoTable, opsToErase); + classOp, instanceGraph, state, pathInfoTable, rtlPortsToCreate, + intraPassMutex, opsToErase); }) .Default([](auto *op) { return success(); }); if (failed(result)) @@ -1326,6 +1507,34 @@ LowerClassesPass::updateInstances(Operation *op, InstanceGraph &instanceGraph, return success(); } +// Create and add all 'ports' lists of RtlPort objects for each object. +void LowerClassesPass::createAllRtlPorts( + const PathInfoTable &pathInfoTable, + hw::InnerSymbolNamespaceCollection &namespaces, + HierPathCache &hierPathCache) { + MLIRContext *ctx = &getContext(); + + // Get a builder initialized to the end of the top-level module. + OpBuilder builder = OpBuilder::atBlockEnd(getOperation().getBodyBlock()); + + // Declare an RtlPort class on the fly. + om::ClassOp::buildSimpleClassOp( + builder, UnknownLoc::get(ctx), kRtlPortClassName, + {"ref", "direction", "width"}, {"ref", "direction", "width"}, + {om::PathType::get(ctx), om::StringType::get(ctx), + om::OMIntegerType::get(ctx)}); + + // Sort the collected rtlPortsToCreate and process each. + llvm::stable_sort(rtlPortsToCreate, [](auto lhs, auto rhs) { + return lhs.object.getClassName() < rhs.object.getClassName(); + }); + + // Create each 'ports' list. + for (auto rtlPortToCreate : rtlPortsToCreate) + createRtlPorts(rtlPortToCreate, pathInfoTable, namespaces, hierPathCache, + builder); +} + // Pattern rewriters for dialect conversion. struct FIntegerConstantOpConversion @@ -1831,6 +2040,9 @@ static void populateTypeConverter(TypeConverter &converter) { // Convert FIRRTL List type to OM List type. auto convertListType = [&converter](auto type) -> std::optional { + // If the element type is already in the OM dialect, there's nothing to do. + if (isa(type.getElementType().getDialect())) + return type; auto elementType = converter.convertType(type.getElementType()); if (!elementType) return {}; diff --git a/test/Dialect/FIRRTL/lower-classes.mlir b/test/Dialect/FIRRTL/lower-classes.mlir index 02fcd5613822..82f518075c6b 100644 --- a/test/Dialect/FIRRTL/lower-classes.mlir +++ b/test/Dialect/FIRRTL/lower-classes.mlir @@ -494,3 +494,53 @@ firrtl.circuit "LocalPath" { firrtl.instance child1 @Child() } } + +// CHECK-LABEL: firrtl.circuit "RTLPorts" +firrtl.circuit "RTLPorts" { + // CHECK: hw.hierpath private [[INPUT_NLA:@.+]] [@SomeModule::[[INPUT_SYM:@.+]]] + // CHECK: hw.hierpath private [[OUTPUT_NLA:@.+]] [@SomeModule::[[OUTPUT_SYM:@.+]]] + + // CHECK: firrtl.module @SomeModule + // CHECK-SAME: in %input: !firrtl.uint<1> sym [[INPUT_SYM]] + // CHECK-SAME: out %output: !firrtl.uint<8> sym [[OUTPUT_SYM]] + firrtl.module @SomeModule(in %input: !firrtl.uint<1>, out %output: !firrtl.uint<8>) attributes {annotations = [{class = "circt.tracker", id = distinct[0]<>}]} { + } + + firrtl.module @RTLPorts() { + firrtl.instance inst @SomeModule(in input: !firrtl.uint<1>, out output: !firrtl.uint<8>) + + // CHECK: [[MODULE_PATH:%.+]] = om.path_create instance %basepath {{.+}} + %path = firrtl.path instance distinct[0]<> + + // CHECK: [[INPUT_REF:%.+]] = om.path_create dont_touch %basepath [[INPUT_NLA]] + // CHECK: [[INPUT_DIR:%.+]] = om.constant "Input" + // CHECK: [[INPUT_WIDTH:%.+]] = om.constant #om.integer<1 : i64> + // CHECK: [[INPUT_OBJ:%.+]] = om.object @RtlPort([[INPUT_REF]], [[INPUT_DIR]], [[INPUT_WIDTH]]) + + // CHECK: [[OUTPUT_REF:%.+]] = om.path_create dont_touch %basepath [[OUTPUT_NLA]] + // CHECK: [[OUTPUT_DIR:%.+]] = om.constant "Output" + // CHECK: [[OUTPUT_WIDTH:%.+]] = om.constant #om.integer<8 : i64> + // CHECK: [[OUTPUT_OBJ:%.+]] = om.object @RtlPort([[OUTPUT_REF]], [[OUTPUT_DIR]], [[OUTPUT_WIDTH]]) + + // CHECK: [[PORTS_LIST:%.+]] = om.list_create [[INPUT_OBJ]], [[OUTPUT_OBJ]] + + // CHECK: om.object @NeedsRTLPorts(%basepath, [[MODULE_PATH]], [[PORTS_LIST]]) + %object = firrtl.object @NeedsRTLPorts(in containingModule_in: !firrtl.path, out containingModule: !firrtl.path) + %field = firrtl.object.subfield %object[containingModule_in] : !firrtl.class<@NeedsRTLPorts(in containingModule_in: !firrtl.path, out containingModule: !firrtl.path)> + firrtl.propassign %field, %path : !firrtl.path + + // Add a second instance to ensure the RtlPorts class is only declared once. + %object2 = firrtl.object @NeedsRTLPorts(in containingModule_in: !firrtl.path, out containingModule: !firrtl.path) + %field2 = firrtl.object.subfield %object2[containingModule_in] : !firrtl.class<@NeedsRTLPorts(in containingModule_in: !firrtl.path, out containingModule: !firrtl.path)> + firrtl.propassign %field2, %path : !firrtl.path + } + + // CHECK: om.class @NeedsRTLPorts + // CHECK-NEXT: om.class.field @containingModule + // CHECK-NEXT: om.class.field @ports + firrtl.class @NeedsRTLPorts(in %containingModule_in: !firrtl.path, out %containingModule: !firrtl.path) { + firrtl.propassign %containingModule, %containingModule_in : !firrtl.path + } + + // CHECK-COUNT-1: om.class @RtlPort +}