Skip to content

Commit

Permalink
[FIRRTL] Add a new FIRRTL annotation to specify type lowering behavio…
Browse files Browse the repository at this point in the history
…r of module body (#7751)

Add a new annotation to control type lowering behavior for internal signals
within a module, separate from the port convention. This allows more fine-grained
control over how aggregate types are handled inside modules.

The new annotation works similarly to ConventionAnnotation but applies to
internal signals rather than module ports. It supports the same conventions
and includes an 'includeHierarchy' option to apply the setting to all
modules in the hierarchy.
  • Loading branch information
uenoku authored Dec 5, 2024
1 parent baccf51 commit 0c1465d
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 17 deletions.
28 changes: 28 additions & 0 deletions docs/Dialects/FIRRTL/FIRRTLAnnotations.md
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,34 @@ The options are:
}
```

### BodyTypeLoweringAnnotation

| Property | Type | Description |
| ------------------- | ------ | ---------------------------------------------------- |
| class | string | `circt.BodyTypeLoweringAnnotation` |
| convention | string | See `Convention` annotation |
| target | string | See `Convention` annotation |
| includeHierarchy | bool | Apply the convention to all modules in the hierarchy |

Specify the type lowering option for module internal signals.
This is similar to the `Convention` annotation, but for internal signals
rather than module ports. Refer to the `Convention` annotation for each
property description.

When `includeHierarchy` is `false`, it indicates the convention is applied only to
the specified module. If `includeHierarchy` is `true`, the convention is applied to
all modules in the hierarchy. If there are multiple annotation instances that specify
conventions, the `scalarized` convention takes precedence over the `internal` convention.

```json
{
"class": "circt.BodyTypeLoweringAnnotation",
"convention": "scalarized",
"target": "~Foo|Bar",
"includeHierarchy": true
}
```

### ElaborationArtefactsDirectory

| Property | Type | Description |
Expand Down
2 changes: 2 additions & 0 deletions include/circt/Dialect/FIRRTL/AnnotationDetails.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ constexpr const char *rawAnnotations = "rawAnnotations";
//===----------------------------------------------------------------------===//

constexpr const char *conventionAnnoClass = "circt.ConventionAnnotation";
constexpr const char *typeLoweringAnnoClass =
"circt.BodyTypeLoweringAnnotation";
constexpr const char *dontTouchAnnoClass =
"firrtl.transforms.DontTouchAnnotation";
constexpr const char *enumComponentAnnoClass =
Expand Down
67 changes: 67 additions & 0 deletions lib/Dialect/FIRRTL/Transforms/LowerAnnotations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,72 @@ static LogicalResult applyConventionAnno(const AnnoPathValue &target,
return error() << "can only target to a module or extmodule";
}

static LogicalResult applyBodyTypeLoweringAnno(const AnnoPathValue &target,
DictionaryAttr anno,
ApplyState &state) {
auto *op = target.ref.getOp();
auto loc = op->getLoc();
auto error = [&]() {
auto diag = mlir::emitError(loc);
diag << typeLoweringAnnoClass;
return diag;
};

auto opTarget = dyn_cast<OpAnnoTarget>(target.ref);
if (!opTarget)
return error() << "must target a module object";

if (!target.isLocal())
return error() << "must be local";

auto moduleOp = dyn_cast<FModuleOp>(op);

if (!moduleOp)
return error() << "can only target to a module";

auto conventionStrAttr =
tryGetAs<StringAttr>(anno, anno, "convention", loc, conventionAnnoClass);

if (!conventionStrAttr)
return failure();

auto conventionStr = conventionStrAttr.getValue();
auto conventionOpt = parseConvention(conventionStr);
if (!conventionOpt)
return error() << "unknown convention " << conventionStr;

auto convention = *conventionOpt;

if (convention == Convention::Internal)
// Convention is internal by default so there is nothing to change
return success();

auto conventionAttr = ConventionAttr::get(op->getContext(), convention);

// `includeHierarchy` only valid in BodyTypeLowering.
bool includeHierarchy = false;
if (auto includeHierarchyAttr = tryGetAs<BoolAttr>(
anno, anno, "includeHierarchy", loc, conventionAnnoClass))
includeHierarchy = includeHierarchyAttr.getValue();

if (includeHierarchy) {
// If includeHierarchy is true, update the convention for all modules in
// the hierarchy.
for (auto *node :
llvm::post_order(state.instancePathCache.instanceGraph[moduleOp])) {
if (!node)
continue;
if (auto fmodule = dyn_cast<FModuleOp>(*node->getModule()))
fmodule->setAttr("body_type_lowering", conventionAttr);
}
} else {
// Update the convention.
moduleOp->setAttr("body_type_lowering", conventionAttr);
}

return success();
}

static LogicalResult applyModulePrefixAnno(const AnnoPathValue &target,
DictionaryAttr anno,
ApplyState &state) {
Expand Down Expand Up @@ -553,6 +619,7 @@ static llvm::StringMap<AnnoRecord> annotationRecords{{
{memTapBlackboxClass, {stdResolve, applyWithoutTarget<true>}},
// Miscellaneous Annotations
{conventionAnnoClass, {stdResolve, applyConventionAnno}},
{typeLoweringAnnoClass, {stdResolve, applyBodyTypeLoweringAnno}},
{dontTouchAnnoClass,
{stdResolve, applyWithoutTarget<true, true, WireOp, NodeOp, RegOp,
RegResetOp, InstanceOp, MemOp, CombMemOp,
Expand Down
42 changes: 27 additions & 15 deletions lib/Dialect/FIRRTL/Transforms/LowerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,12 +332,17 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {

TypeLoweringVisitor(
MLIRContext *context, PreserveAggregate::PreserveMode preserveAggregate,
Convention bodyConvention,
PreserveAggregate::PreserveMode memoryPreservationMode,
SymbolTable &symTbl, const AttrCache &cache,
const llvm::DenseMap<FModuleLike, Convention> &conventionTable)
: context(context), aggregatePreservationMode(preserveAggregate),
: context(context), defaultAggregatePreservationMode(preserveAggregate),
memoryPreservationMode(memoryPreservationMode), symTbl(symTbl),
cache(cache), conventionTable(conventionTable) {}
cache(cache), conventionTable(conventionTable) {
bodyAggregatePreservationMode = bodyConvention == Convention::Scalarized
? PreserveAggregate::None
: defaultAggregatePreservationMode;
}
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitDecl;
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitExpr;
using FIRRTLVisitor<TypeLoweringVisitor, bool>::visitStmt;
Expand Down Expand Up @@ -422,7 +427,7 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
Location errorLoc);

PreserveAggregate::PreserveMode
getPreservationModeForModule(FModuleLike moduleLike);
getPreservationModeForPorts(FModuleLike moduleLike);
Value getSubWhatever(Value val, size_t index);

size_t uniqueIdx = 0;
Expand All @@ -434,7 +439,8 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
MLIRContext *context;

/// Aggregate preservation mode.
PreserveAggregate::PreserveMode aggregatePreservationMode;
PreserveAggregate::PreserveMode defaultAggregatePreservationMode;
PreserveAggregate::PreserveMode bodyAggregatePreservationMode;
PreserveAggregate::PreserveMode memoryPreservationMode;

/// The builder is set and maintained in the main loop.
Expand All @@ -453,21 +459,21 @@ struct TypeLoweringVisitor : public FIRRTLVisitor<TypeLoweringVisitor, bool> {
};
} // namespace

/// Return aggregate preservation mode for the module. If the module has a
/// Return aggregate preservation mode for the module ports. If the module has a
/// scalarized linkage, then we may not preserve it's aggregate ports.
PreserveAggregate::PreserveMode
TypeLoweringVisitor::getPreservationModeForModule(FModuleLike module) {
TypeLoweringVisitor::getPreservationModeForPorts(FModuleLike module) {
auto lookup = conventionTable.find(module);
if (lookup == conventionTable.end())
return aggregatePreservationMode;
return defaultAggregatePreservationMode;
switch (lookup->second) {
case Convention::Scalarized:
return PreserveAggregate::None;
case Convention::Internal:
return aggregatePreservationMode;
return defaultAggregatePreservationMode;
}
llvm_unreachable("Unknown convention");
return aggregatePreservationMode;
return defaultAggregatePreservationMode;
}

Value TypeLoweringVisitor::getSubWhatever(Value val, size_t index) {
Expand Down Expand Up @@ -636,7 +642,7 @@ bool TypeLoweringVisitor::lowerProducer(
return false;
SmallVector<FlatBundleFieldEntry, 8> fieldTypes;

if (!peelType(srcFType, fieldTypes, aggregatePreservationMode))
if (!peelType(srcFType, fieldTypes, bodyAggregatePreservationMode))
return false;

SmallVector<Value> lowered;
Expand Down Expand Up @@ -805,7 +811,7 @@ bool TypeLoweringVisitor::lowerArg(FModuleLike module, size_t argIndex,
// Flatten any bundle types.
SmallVector<FlatBundleFieldEntry> fieldTypes;
auto srcType = type_cast<FIRRTLType>(newArgs[argIndex].pi.type);
if (!peelType(srcType, fieldTypes, getPreservationModeForModule(module)))
if (!peelType(srcType, fieldTypes, getPreservationModeForPorts(module)))
return false;

// Ports with internalPath set cannot be lowered.
Expand Down Expand Up @@ -925,7 +931,7 @@ bool TypeLoweringVisitor::visitStmt(RefDefineOp op) {
// Attempt to get the bundle types.
SmallVector<FlatBundleFieldEntry> fields;

if (!peelType(op.getDest().getType(), fields, aggregatePreservationMode))
if (!peelType(op.getDest().getType(), fields, bodyAggregatePreservationMode))
return false;

// Loop over the leaf aggregates.
Expand Down Expand Up @@ -1458,7 +1464,7 @@ bool TypeLoweringVisitor::visitDecl(InstanceOp op) {
SmallVector<Direction> newDirs;
SmallVector<Attribute> newNames;
SmallVector<Attribute> newPortAnno;
PreserveAggregate::PreserveMode mode = getPreservationModeForModule(
PreserveAggregate::PreserveMode mode = getPreservationModeForPorts(
cast<FModuleLike>(op.getReferencedOperation(symTbl)));

endFields.push_back(0);
Expand Down Expand Up @@ -1662,9 +1668,15 @@ void LowerTypesPass::runOnOperation() {

// This lambda, executes in parallel for each Op within the circt.
auto lowerModules = [&](FModuleLike op) -> LogicalResult {
// Use body type lowering attribute if it exists, otherwise use internal.
Convention convention = Convention::Internal;
if (auto conventionAttr = dyn_cast_or_null<ConventionAttr>(
op->getDiscardableAttr("body_type_lowering")))
convention = conventionAttr.getValue();

auto tl =
TypeLoweringVisitor(&getContext(), preserveAggregate, preserveMemories,
symTbl, cache, conventionTable);
TypeLoweringVisitor(&getContext(), preserveAggregate, convention,
preserveMemories, symTbl, cache, conventionTable);
tl.lowerModule(op);

return LogicalResult::failure(tl.isFailed());
Expand Down
23 changes: 21 additions & 2 deletions test/Dialect/FIRRTL/annotations.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -734,14 +734,33 @@ firrtl.circuit "Test" attributes {rawAnnotations = [
// -----

firrtl.circuit "Test" attributes {rawAnnotations =[
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"}
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"},
{class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = false}
]} {
// CHECK: attributes {convention = #firrtl<convention scalarized>}
// CHECK: attributes {body_type_lowering = #firrtl<convention scalarized>, convention = #firrtl<convention scalarized>}
firrtl.module @Test() attributes {convention = #firrtl<convention internal>} {}
}

// -----

firrtl.circuit "Test" attributes {rawAnnotations = [
{class = "circt.ConventionAnnotation", target = "~Test|Test", convention = "scalarized"},
{class = "circt.BodyTypeLoweringAnnotation", target = "~Test|Test", convention = "scalarized", includeHierarchy = true}
]} {
// CHECK: @Test() attributes {body_type_lowering = #firrtl<convention scalarized>, convention = #firrtl<convention scalarized>}
firrtl.module @Test() attributes {convention = #firrtl<convention internal>} {
firrtl.instance child @Child()
}

// CHECK: @Child() attributes {body_type_lowering = #firrtl<convention scalarized>}
firrtl.module @Child() attributes {convention = #firrtl<convention internal>} {}

// CHECK: @Child2() {
firrtl.module @Child2() attributes {convention = #firrtl<convention internal>} {}
}

// -----

firrtl.circuit "Test" attributes {rawAnnotations =[
{class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>comb", prefix = "Prefix_"},
{class = "chisel3.ModulePrefixAnnotation", target = "~Test|Test>seq", prefix = "Prefix_"},
Expand Down
39 changes: 39 additions & 0 deletions test/Dialect/FIRRTL/lower-types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1405,6 +1405,45 @@ firrtl.circuit "UnrealizedConversion" {
}
}

firrtl.circuit "Conventions1" {
// COMMON-LABEL: @Conventions1
// AGGREGATE-SAME: %input_0
// AGGREGATE-NEXT: firrtl.reg
// AGGREGATE-SAME: !firrtl.vector<uint<8>, 1>
firrtl.module public @Conventions1(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention scalarized>, body_type_lowering = #firrtl<convention internal>}{
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
}
// COMMON-LABEL: @Conventions2
// AGGREGATE-SAME: %input_0: !firrtl.uint<8>
// AGGREGATE-NEXT: firrtl.reg
// AGGREGATE-SAME: !firrtl.uint<8>
firrtl.module private @Conventions2(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention scalarized>, body_type_lowering = #firrtl<convention scalarized>}{
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
}
// COMMON-LABEL: @Conventions3
// AGGREGATE-SAME: %input: !firrtl.vector<uint<8>, 1>
// AGGREGATE-NEXT: firrtl.reg
// AGGREGATE-SAME: !firrtl.vector<uint<8>, 1>
firrtl.module private @Conventions3(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention internal>, body_type_lowering = #firrtl<convention internal>}{
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
}
// COMMON-LABEL: @Conventions4
// AGGREGATE-SAME: %input: !firrtl.vector<uint<8>, 1>
// AGGREGATE-NEXT: firrtl.reg
// AGGREGATE-SAME: !firrtl.uint<8>
firrtl.module private @Conventions4(in %input: !firrtl.vector<uint<8>, 1>, in %clk: !firrtl.clock, out %port: !firrtl.vector<uint<8>, 1>) attributes {convention = #firrtl<convention internal>, body_type_lowering = #firrtl<convention scalarized>}{
%r = firrtl.reg interesting_name %clk : !firrtl.clock, !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %r, %input : !firrtl.vector<uint<8>, 1>
firrtl.matchingconnect %port, %r : !firrtl.vector<uint<8>, 1>
}
}

// Test that memories have their prefixes copied when lowering.
// See: https://github.com/llvm/circt/issues/7835
firrtl.circuit "MemoryPrefixCopying" {
Expand Down

0 comments on commit 0c1465d

Please sign in to comment.