Skip to content

Commit f77c002

Browse files
authored
[NFC] Cache common lookups in ModuleType (#6892)
Use custom storage for ModuleType to cache input/output <-> index mappings. Speeds up many things in small ways.
1 parent 562f4d7 commit f77c002

File tree

5 files changed

+82
-57
lines changed

5 files changed

+82
-57
lines changed

include/circt/Dialect/HW/HWTypes.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,50 @@ struct ModulePort {
3131
Direction dir;
3232
};
3333

34+
static bool operator==(const ModulePort &a, const ModulePort &b) {
35+
return a.dir == b.dir && a.name == b.name && a.type == b.type;
36+
}
37+
static llvm::hash_code hash_value(const ModulePort &port) {
38+
return llvm::hash_combine(port.dir, port.name, port.type);
39+
}
40+
41+
namespace detail {
42+
struct ModuleTypeStorage : public TypeStorage {
43+
ModuleTypeStorage(ArrayRef<ModulePort> inPorts);
44+
45+
using KeyTy = ArrayRef<ModulePort>;
46+
47+
/// Define the comparison function for the key type.
48+
bool operator==(const KeyTy &key) const {
49+
return std::equal(key.begin(), key.end(), ports.begin(), ports.end());
50+
}
51+
52+
/// Define a hash function for the key type.
53+
static llvm::hash_code hashKey(const KeyTy &key) {
54+
return llvm::hash_combine_range(key.begin(), key.end());
55+
}
56+
57+
/// Define a construction method for creating a new instance of this storage.
58+
static ModuleTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
59+
const KeyTy &key) {
60+
return new (allocator.allocate<ModuleTypeStorage>()) ModuleTypeStorage(key);
61+
}
62+
63+
/// Construct an instance of the key from this storage class.
64+
KeyTy getAsKey() const { return ports; }
65+
66+
ArrayRef<ModulePort> getPorts() const { return ports; }
67+
68+
/// The parametric data held by the storage class.
69+
SmallVector<ModulePort> ports;
70+
// Cache of common lookups
71+
SmallVector<size_t> inputToAbs;
72+
SmallVector<size_t> outputToAbs;
73+
SmallVector<size_t> absToInput;
74+
SmallVector<size_t> absToOutput;
75+
};
76+
} // namespace detail
77+
3478
class HWSymbolCache;
3579
class ParamDeclAttr;
3680
class TypedeclOp;

include/circt/Dialect/HW/HWTypesImpl.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ def ModuleTypeImpl : HWType<"Module"> {
243243
let hasCustomAssemblyFormat = 1;
244244
let genVerifyDecl = 1;
245245
let mnemonic = "modty";
246+
let genStorageClass = 0;
246247

247248
let extraClassDeclaration = [{
248249
// Many of these are transitional and will be removed when modules and instances

lib/CAPI/Dialect/FIRRTL.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ FIRRTLValueFlow firrtlValueFoldFlow(MlirValue value, FIRRTLValueFlow flow) {
311311
case Flow::Duplex:
312312
return FIRRTL_VALUE_FLOW_DUPLEX;
313313
}
314+
llvm_unreachable("invalid flow");
314315
}
315316

316317
bool firrtlImportAnnotationsFromJSONRaw(

lib/Dialect/HW/HWOps.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,8 +1075,6 @@ static LogicalResult verifyModuleCommon(HWModuleLike module) {
10751075
assert(isa<HWModuleLike>(module) &&
10761076
"verifier hook should only be called on modules");
10771077

1078-
auto moduleType = module.getHWModuleType();
1079-
10801078
SmallPtrSet<Attribute, 4> paramNames;
10811079

10821080
// Check parameter default values are sensible.

lib/Dialect/HW/HWTypes.cpp

Lines changed: 36 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -823,60 +823,30 @@ LogicalResult ModuleType::verify(function_ref<InFlightDiagnostic()> emitError,
823823
}
824824

825825
size_t ModuleType::getPortIdForInputId(size_t idx) {
826-
for (auto [i, p] : llvm::enumerate(getPorts())) {
827-
if (p.dir != ModulePort::Direction::Output) {
828-
if (!idx)
829-
return i;
830-
--idx;
831-
}
832-
}
833-
assert(0 && "Out of bounds input port id");
834-
return ~0UL;
826+
assert(idx < getImpl()->inputToAbs.size() && "input port out of range");
827+
return getImpl()->inputToAbs[idx];
835828
}
836829

837830
size_t ModuleType::getPortIdForOutputId(size_t idx) {
838-
for (auto [i, p] : llvm::enumerate(getPorts())) {
839-
if (p.dir == ModulePort::Direction::Output) {
840-
if (!idx)
841-
return i;
842-
--idx;
843-
}
844-
}
845-
assert(0 && "Out of bounds output port id");
846-
return ~0UL;
831+
assert(idx < getImpl()->outputToAbs.size() && " output port out of range");
832+
return getImpl()->outputToAbs[idx];
847833
}
848834

849835
size_t ModuleType::getInputIdForPortId(size_t idx) {
850-
auto ports = getPorts();
851-
assert(ports[idx].dir != ModulePort::Direction::Output);
852-
size_t retval = 0;
853-
for (size_t i = 0; i < idx; ++i)
854-
if (ports[i].dir != ModulePort::Direction::Output)
855-
++retval;
856-
return retval;
836+
auto nIdx = getImpl()->absToInput[idx];
837+
assert(nIdx != ~0ULL);
838+
return nIdx;
857839
}
858840

859841
size_t ModuleType::getOutputIdForPortId(size_t idx) {
860-
auto ports = getPorts();
861-
assert(ports[idx].dir == ModulePort::Direction::Output);
862-
size_t retval = 0;
863-
for (size_t i = 0; i < idx; ++i)
864-
if (ports[i].dir == ModulePort::Direction::Output)
865-
++retval;
866-
return retval;
842+
auto nIdx = getImpl()->absToOutput[idx];
843+
assert(nIdx != ~0ULL);
844+
return nIdx;
867845
}
868846

869-
size_t ModuleType::getNumInputs() {
870-
return std::count_if(getPorts().begin(), getPorts().end(), [](auto &p) {
871-
return p.dir != ModulePort::Direction::Output;
872-
});
873-
}
847+
size_t ModuleType::getNumInputs() { return getImpl()->inputToAbs.size(); }
874848

875-
size_t ModuleType::getNumOutputs() {
876-
return std::count_if(getPorts().begin(), getPorts().end(), [](auto &p) {
877-
return p.dir == ModulePort::Direction::Output;
878-
});
879-
}
849+
size_t ModuleType::getNumOutputs() { return getImpl()->outputToAbs.size(); }
880850

881851
size_t ModuleType::getNumPorts() { return getPorts().size(); }
882852

@@ -984,6 +954,10 @@ FunctionType ModuleType::getFuncType() {
984954
return FunctionType::get(getContext(), inputs, outputs);
985955
}
986956

957+
ArrayRef<ModulePort> ModuleType::getPorts() const {
958+
return getImpl()->getPorts();
959+
}
960+
987961
FailureOr<ModuleType> ModuleType::resolveParametricTypes(ArrayAttr parameters,
988962
LocationAttr loc,
989963
bool emitErrors) {
@@ -1021,7 +995,7 @@ static ModulePort::Direction strToDir(StringRef str) {
1021995
}
1022996

1023997
/// Parse a list of field names and types within <>. E.g.:
1024-
/// <foo: i7, bar: i8>
998+
/// <input foo: i7, output bar: i8>
1025999
static ParseResult parsePorts(AsmParser &p,
10261000
SmallVectorImpl<ModulePort> &ports) {
10271001
return p.parseCommaSeparatedList(
@@ -1060,18 +1034,6 @@ void ModuleType::print(AsmPrinter &odsPrinter) const {
10601034
printPorts(odsPrinter, getPorts());
10611035
}
10621036

1063-
namespace circt {
1064-
namespace hw {
1065-
1066-
static bool operator==(const ModulePort &a, const ModulePort &b) {
1067-
return a.dir == b.dir && a.name == b.name && a.type == b.type;
1068-
}
1069-
static llvm::hash_code hash_value(const ModulePort &port) {
1070-
return llvm::hash_combine(port.dir, port.name, port.type);
1071-
}
1072-
} // namespace hw
1073-
} // namespace circt
1074-
10751037
ModuleType circt::hw::detail::fnToMod(Operation *op,
10761038
ArrayRef<Attribute> inputNames,
10771039
ArrayRef<Attribute> outputNames) {
@@ -1109,6 +1071,25 @@ ModuleType circt::hw::detail::fnToMod(FunctionType fnty,
11091071
return ModuleType::get(fnty.getContext(), ports);
11101072
}
11111073

1074+
detail::ModuleTypeStorage::ModuleTypeStorage(ArrayRef<ModulePort> inPorts)
1075+
: ports(inPorts) {
1076+
size_t nextInput = 0;
1077+
size_t nextOutput = 0;
1078+
for (auto [idx, p] : llvm::enumerate(ports)) {
1079+
if (p.dir == ModulePort::Direction::Output) {
1080+
outputToAbs.push_back(idx);
1081+
absToOutput.push_back(nextOutput);
1082+
absToInput.push_back(~0ULL);
1083+
++nextOutput;
1084+
} else {
1085+
inputToAbs.push_back(idx);
1086+
absToInput.push_back(nextInput);
1087+
absToOutput.push_back(~0ULL);
1088+
++nextInput;
1089+
}
1090+
}
1091+
}
1092+
11121093
////////////////////////////////////////////////////////////////////////////////
11131094
// BoilerPlate
11141095
////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)