Skip to content

Commit

Permalink
[RTG][Elaboration] Support interleave_sequences, factor our sequence …
Browse files Browse the repository at this point in the history
…inlining and label resolution
  • Loading branch information
maerhart committed Feb 13, 2025
1 parent d748483 commit 0906f72
Show file tree
Hide file tree
Showing 10 changed files with 478 additions and 155 deletions.
3 changes: 2 additions & 1 deletion include/circt/Dialect/RTG/IR/RTGVisitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class RTGOpVisitor {
RandomNumberInRangeOp,
// Sequences
SequenceOp, GetSequenceOp, SubstituteSequenceOp,
RandomizeSequenceOp, EmbedSequenceOp,
RandomizeSequenceOp, EmbedSequenceOp, InterleaveSequencesOp,
// Sets
SetCreateOp, SetSelectRandomOp, SetDifferenceOp, SetUnionOp,
SetSizeOp>([&](auto expr) -> ResultType {
Expand Down Expand Up @@ -86,6 +86,7 @@ class RTGOpVisitor {
HANDLE(GetSequenceOp, Unhandled);
HANDLE(SubstituteSequenceOp, Unhandled);
HANDLE(RandomizeSequenceOp, Unhandled);
HANDLE(InterleaveSequencesOp, Unhandled);
HANDLE(EmbedSequenceOp, Unhandled);
HANDLE(RandomNumberInRangeOp, Unhandled);
HANDLE(OnContextOp, Unhandled);
Expand Down
30 changes: 30 additions & 0 deletions include/circt/Dialect/RTG/Transforms/RTGPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,22 @@ def EmitRTGISAAssemblyPass : Pass<"rtg-emit-isa-assembly", "mlir::ModuleOp"> {
];
}

def InlineSequencesPass : Pass<"rtg-inline-sequences", "mlir::ModuleOp"> {
let summary = "inline and interleave sequences";
let description = [{
Inline all sequences into tests and remove the 'rtg.sequence' operations.
Also computes and materializes all interleaved sequences
('interleave_sequences' operation).
}];

let statistics = [
Statistic<"numSequencesInlined", "num-sequences-inlined",
"Number of sequences inlined into another sequence or test.">,
Statistic<"numSequencesInterleaved", "num-sequences-interleaved",
"Number of sequences interleaved with another sequence.">,
];
}

def LinearScanRegisterAllocationPass : Pass<
"rtg-linear-scan-register-allocation", "rtg::TestOp"> {

Expand All @@ -81,4 +97,18 @@ def LinearScanRegisterAllocationPass : Pass<
];
}

def LowerUniqueLabelsPass : Pass<"rtg-lower-unique-labels", "mlir::ModuleOp"> {
let summary = "lower label_unique_decl to label_decl operations";
let description = [{
This pass lowers label_unique_decl operations to label_decl operations by
creating a unique label string based on all the existing unique and
non-unique label declarations in the module.
}];

let statistics = [
Statistic<"numLabelsLowered", "num-labels-lowered",
"Number of unique labels lowered to regular label declarations.">,
];
}

#endif // CIRCT_DIALECT_RTG_TRANSFORMS_RTGPASSES_TD
2 changes: 2 additions & 0 deletions lib/Dialect/RTG/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
add_circt_dialect_library(CIRCTRTGTransforms
ElaborationPass.cpp
EmitRTGISAAssemblyPass.cpp
InlineSequencesPass.cpp
LinearScanRegisterAllocationPass.cpp
LowerUniqueLabelsPass.cpp

DEPENDS
CIRCTRTGTransformsIncGen
Expand Down
188 changes: 109 additions & 79 deletions lib/Dialect/RTG/Transforms/ElaborationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ namespace {
struct BagStorage;
struct SequenceStorage;
struct RandomizedSequenceStorage;
struct InterleavedSequenceStorage;
struct SetStorage;
struct VirtualRegisterStorage;
struct UniqueLabelStorage;
Expand All @@ -107,8 +108,9 @@ struct LabelValue {
/// The abstract base class for elaborated values.
using ElaboratorValue =
std::variant<TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
RandomizedSequenceStorage *, SetStorage *,
VirtualRegisterStorage *, UniqueLabelStorage *, LabelValue>;
RandomizedSequenceStorage *, InterleavedSequenceStorage *,
SetStorage *, VirtualRegisterStorage *, UniqueLabelStorage *,
LabelValue>;

// NOLINTNEXTLINE(readability-identifier-naming)
llvm::hash_code hash_value(const LabelValue &val) {
Expand Down Expand Up @@ -309,6 +311,34 @@ struct RandomizedSequenceStorage {
const SequenceStorage *sequence;
};

/// Storage object for interleaved '!rtg.randomized_sequence'es.
struct InterleavedSequenceStorage {
InterleavedSequenceStorage(SmallVector<ElaboratorValue> &&sequences,
uint32_t batchSize)
: sequences(std::move(sequences)), batchSize(batchSize),
hashcode(llvm::hash_combine(
llvm::hash_combine_range(sequences.begin(), sequences.end()),
batchSize)) {}

explicit InterleavedSequenceStorage(RandomizedSequenceStorage *sequence)
: sequences(SmallVector<ElaboratorValue>(1, sequence)), batchSize(1),
hashcode(llvm::hash_combine(
llvm::hash_combine_range(sequences.begin(), sequences.end()),
batchSize)) {}

bool isEqual(const InterleavedSequenceStorage *other) const {
return hashcode == other->hashcode && sequences == other->sequences &&
batchSize == other->batchSize;
}

const SmallVector<ElaboratorValue> sequences;

const uint32_t batchSize;

// The cached hashcode to avoid repeated computations.
const unsigned hashcode;
};

/// Represents a unique virtual register.
struct VirtualRegisterStorage {
VirtualRegisterStorage(ArrayAttr allowedRegs) : allowedRegs(allowedRegs) {}
Expand Down Expand Up @@ -373,6 +403,8 @@ class Internalizer {
return internedSequences;
else if constexpr (std::is_same_v<StorageTy, RandomizedSequenceStorage>)
return internedRandomizedSequences;
else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
return internedInterleavedSequences;
else
static_assert(!sizeof(StorageTy),
"no intern set available for this storage type.");
Expand All @@ -392,6 +424,9 @@ class Internalizer {
DenseSet<HashedStorage<RandomizedSequenceStorage>,
StorageKeyInfo<RandomizedSequenceStorage>>
internedRandomizedSequences;
DenseSet<HashedStorage<InterleavedSequenceStorage>,
StorageKeyInfo<InterleavedSequenceStorage>>
internedInterleavedSequences;
};

} // namespace
Expand Down Expand Up @@ -438,6 +473,13 @@ static void print(RandomizedSequenceStorage *val, llvm::raw_ostream &os) {
os << ") at " << val << ">";
}

static void print(InterleavedSequenceStorage *val, llvm::raw_ostream &os) {
os << "<interleaved-sequence [";
llvm::interleaveComma(val->sequences, os,
[&](const ElaboratorValue &val) { os << val; });
os << "] batch-size " << val->batchSize << " at " << val << ">";
}

static void print(SetStorage *val, llvm::raw_ostream &os) {
os << "<set {";
llvm::interleaveComma(val->set, os,
Expand Down Expand Up @@ -677,7 +719,25 @@ class Materializer {
elabRequests.push(val);
Value seq = builder.create<GetSequenceOp>(
loc, SequenceType::get(builder.getContext(), {}), val->name);
return builder.create<RandomizeSequenceOp>(loc, seq);
Value res = builder.create<RandomizeSequenceOp>(loc, seq);
materializedValues[val] = res;
return res;
}

Value visit(InterleavedSequenceStorage *val, Location loc,
std::queue<RandomizedSequenceStorage *> &elabRequests,
function_ref<InFlightDiagnostic()> emitError) {
SmallVector<Value> sequences;
for (auto seqVal : val->sequences)
sequences.push_back(materialize(seqVal, loc, elabRequests, emitError));

if (sequences.size() == 1)
return sequences[0];

Value res =
builder.create<InterleaveSequencesOp>(loc, sequences, val->batchSize);
materializedValues[val] = res;
return res;
}

Value visit(VirtualRegisterStorage *val, Location loc,
Expand Down Expand Up @@ -735,7 +795,6 @@ struct ElaboratorSharedState {
SymbolTable &table;
std::mt19937 rng;
Namespace names;
Namespace labelNames;
Internalizer internalizer;

/// The worklist used to keep track of the test and sequence operations to
Expand Down Expand Up @@ -841,27 +900,57 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
auto *seq = get<SequenceStorage *>(op.getSequence());

auto name = sharedState.names.newName(seq->familyName.getValue());
state[op.getResult()] =
auto *randomizedSeq =
sharedState.internalizer.internalize<RandomizedSequenceStorage>(
name, currentContext, testState.name, seq);
state[op.getResult()] =
sharedState.internalizer.internalize<InterleavedSequenceStorage>(
randomizedSeq);
return DeletionKind::Delete;
}

FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
auto *seq = get<RandomizedSequenceStorage *>(op.getSequence());
if (seq->context != currentContext) {
auto err = op->emitError("attempting to place sequence ")
<< seq->name << " derived from "
<< seq->sequence->familyName.getValue() << " under context "
<< currentContext
<< ", but it was previously randomized for context ";
if (seq->context)
err << seq->context;
else
err << "'default'";
return err;
FailureOr<DeletionKind> visitOp(InterleaveSequencesOp op) {
SmallVector<ElaboratorValue> sequences;
for (auto seq : op.getSequences())
sequences.push_back(get<InterleavedSequenceStorage *>(seq));

state[op.getResult()] =
sharedState.internalizer.internalize<InterleavedSequenceStorage>(
std::move(sequences), op.getBatchSize());
return DeletionKind::Delete;
}

// NOLINTNEXTLINE(misc-no-recursion)
LogicalResult isValidContext(ElaboratorValue value, Operation *op) const {
if (std::holds_alternative<RandomizedSequenceStorage *>(value)) {
auto *seq = std::get<RandomizedSequenceStorage *>(value);
if (seq->context != currentContext) {
auto err = op->emitError("attempting to place sequence ")
<< seq->name << " derived from "
<< seq->sequence->familyName.getValue() << " under context "
<< currentContext
<< ", but it was previously randomized for context ";
if (seq->context)
err << seq->context;
else
err << "'default'";
return err;
}
return success();
}

auto *interVal = std::get<InterleavedSequenceStorage *>(value);
for (auto val : interVal->sequences)
if (failed(isValidContext(val, op)))
return failure();
return success();
}

FailureOr<DeletionKind> visitOp(EmbedSequenceOp op) {
auto *seqVal = get<InterleavedSequenceStorage *>(op.getSequence());
if (failed(isValidContext(seqVal, op)))
return failure();

return DeletionKind::Keep;
}

Expand Down Expand Up @@ -1039,7 +1128,6 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
FailureOr<DeletionKind> visitOp(LabelDeclOp op) {
auto substituted =
substituteFormatString(op.getFormatStringAttr(), op.getArgs());
sharedState.labelNames.add(substituted.getValue());
state[op.getLabel()] = LabelValue(substituted);
return DeletionKind::Delete;
}
Expand Down Expand Up @@ -1309,7 +1397,6 @@ struct ElaborationPass
void runOnOperation() override;
void cloneTargetsIntoTests(SymbolTable &table);
LogicalResult elaborateModule(ModuleOp moduleOp, SymbolTable &table);
LogicalResult inlineSequences(TestOp testOp, SymbolTable &table);
};
} // namespace

Expand Down Expand Up @@ -1407,6 +1494,8 @@ LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
auto seqOp = builder.cloneWithoutRegions(familyOp);
seqOp.getBodyRegion().emplaceBlock();
seqOp.setSymName(curr->name);
seqOp.setSequenceType(
SequenceType::get(builder.getContext(), ArrayRef<Type>{}));
table.insert(seqOp);
assert(seqOp.getSymName() == curr->name && "should not have been renamed");

Expand All @@ -1425,64 +1514,5 @@ LogicalResult ElaborationPass::elaborateModule(ModuleOp moduleOp,
materializer.finalize();
}

for (auto testOp : moduleOp.getOps<TestOp>()) {
// Inline all sequences and remove the operations that place the sequences.
if (failed(inlineSequences(testOp, table)))
return failure();

// Convert 'rtg.label_unique_decl' to 'rtg.label_decl' by choosing a unique
// name based on the set of names we collected during elaboration.
for (auto labelOp :
llvm::make_early_inc_range(testOp.getOps<LabelUniqueDeclOp>())) {
IRRewriter rewriter(labelOp);
auto newName = state.labelNames.newName(labelOp.getFormatString());
rewriter.replaceOpWithNewOp<LabelDeclOp>(labelOp, newName, ValueRange());
}
}

// Remove all sequences since they are not accessible from the outside and
// are not needed anymore since we fully inlined them.
for (auto seqOp : llvm::make_early_inc_range(moduleOp.getOps<SequenceOp>()))
seqOp->erase();

return success();
}

LogicalResult ElaborationPass::inlineSequences(TestOp testOp,
SymbolTable &table) {
OpBuilder builder(testOp);
for (auto iter = testOp.getBody()->begin();
iter != testOp.getBody()->end();) {
auto embedOp = dyn_cast<EmbedSequenceOp>(&*iter);
if (!embedOp) {
++iter;
continue;
}

auto randSeqOp = embedOp.getSequence().getDefiningOp<RandomizeSequenceOp>();
if (!randSeqOp)
return embedOp->emitError("sequence operand not directly defined by "
"'rtg.randomize_sequence' op");
auto getSeqOp = randSeqOp.getSequence().getDefiningOp<GetSequenceOp>();
if (!getSeqOp)
return randSeqOp->emitError(
"sequence operand not directly defined by 'rtg.get_sequence' op");

auto seqOp = table.lookup<SequenceOp>(getSeqOp.getSequenceAttr());

builder.setInsertionPointAfter(embedOp);
IRMapping mapping;
for (auto &op : *seqOp.getBody())
builder.clone(op, mapping);

(iter++)->erase();

if (randSeqOp->use_empty())
randSeqOp->erase();

if (getSeqOp->use_empty())
getSeqOp->erase();
}

return success();
}
Loading

0 comments on commit 0906f72

Please sign in to comment.