Skip to content

Commit

Permalink
[RTG] Custom parser/printer for sequence op and type for sequence fam…
Browse files Browse the repository at this point in the history
…ilies

Allow type parameters in the sequence type and attach it as a type attribute to the sequence op. That way ops referring to a sequence don't have to access the operation's body to verify the type.
  • Loading branch information
maerhart committed Jan 30, 2025
1 parent 1f4427a commit 1058927
Show file tree
Hide file tree
Showing 13 changed files with 188 additions and 56 deletions.
11 changes: 10 additions & 1 deletion include/circt-c/Dialect/RTG.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,16 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(RTG, rtg);
MLIR_CAPI_EXPORTED bool rtgTypeIsASequence(MlirType type);

/// Creates an RTG sequence type in the context.
MLIR_CAPI_EXPORTED MlirType rtgSequenceTypeGet(MlirContext ctxt);
MLIR_CAPI_EXPORTED MlirType rtgSequenceTypeGet(MlirContext ctxt,
intptr_t numElements,
MlirType const *elementTypes);

/// The number of substitution elements of the RTG sequence.
MLIR_CAPI_EXPORTED unsigned rtgSequenceTypeGetNumElements(MlirType type);

/// The type of of the substitution element at the given index.
MLIR_CAPI_EXPORTED MlirType rtgSequenceTypeGetElement(MlirType type,
unsigned i);

/// If the type is an RTG label.
MLIR_CAPI_EXPORTED bool rtgTypeIsALabel(MlirType type);
Expand Down
13 changes: 7 additions & 6 deletions include/circt/Dialect/RTG/IR/RTGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ def SequenceOp : RTGOp<"sequence", [
stronger top-level isolation guarantees.
}];

let arguments = (ins SymbolNameAttr:$sym_name);
let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<SequenceType>:$sequenceType);

let regions = (region SizedRegion<1>:$bodyRegion);

let assemblyFormat = [{
$sym_name attr-dict-with-keyword $bodyRegion
}];
let hasCustomAssemblyFormat = 1;
let hasRegionVerifier = 1;
}

def SequenceClosureOp : RTGOp<"sequence_closure", [
Expand All @@ -77,7 +78,7 @@ def SequenceClosureOp : RTGOp<"sequence_closure", [
}];

let arguments = (ins SymbolNameAttr:$sequence, Variadic<AnyType>:$args);
let results = (outs SequenceType:$ref);
let results = (outs FullySubstitutedSequenceType:$ref);

let assemblyFormat = [{
$sequence (`(` $args^ `:` qualified(type($args)) `)`)? attr-dict
Expand All @@ -94,7 +95,7 @@ def InvokeSequenceOp : RTGOp<"invoke_sequence", []> {
were directly inlined relacing this operation.
}];

let arguments = (ins SequenceType:$sequence);
let arguments = (ins FullySubstitutedSequenceType:$sequence);

let assemblyFormat = "$sequence attr-dict";
}
Expand Down
19 changes: 15 additions & 4 deletions include/circt/Dialect/RTG/IR/RTGTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,27 @@ include "mlir/IR/AttrTypeBase.td"
class RTGTypeDef<string name> : TypeDef<RTGDialect, name>;

def SequenceType : RTGTypeDef<"Sequence"> {
let summary = "handle to a sequence closure";
let summary = "handle to a sequence or sequence family";
let description = [{
An SSA value of this type refers to an `rtg.sequence` operation and the
argument values it should be invoked with (if it has any).
An SSA value of this type refers to a sequence if the list of element types
is empty or a sequence family if there are elements left to be substituted.
}];

let parameters = (ins OptionalArrayRefParameter<
"mlir::Type", "element types">:$elementTypes);

let mnemonic = "sequence";
let assemblyFormat = "";
let assemblyFormat = "(`<` $elementTypes^ `>`)?";
}

def FullySubstitutedSequenceType : DialectType<RTGDialect,
CPred<"llvm::isa<rtg::SequenceType>($_self) && "
"llvm::cast<rtg::SequenceType>($_self).getElementTypes().empty()">,
"fully substituted sequence type", "::circt::rtg::SequenceType">,
BuildableType<
"::circt::rtg::SequenceType::get($_builder.getContext(), " #
"llvm::ArrayRef<::mlir::Type>{})">;

def LabelType : RTGTypeDef<"Label"> {
let summary = "a reference to a label";
let description = [{
Expand Down
18 changes: 10 additions & 8 deletions integration_test/Bindings/Python/dialects/rtg.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,20 @@
circt.register_dialects(ctx)
m = Module.create()
with InsertionPoint(m.body):
seq = rtg.SequenceOp('seq')
setTy = rtg.SetType.get(rtg.SequenceType.get())
seq = rtg.SequenceOp('seq', TypeAttr.get(rtg.SequenceType.get([setTy])))
seqBlock = Block.create_at_start(seq.bodyRegion, [setTy])

# CHECK: rtg.sequence @seq {
# CHECK: ^bb{{.*}}(%{{.*}}: !rtg.set<!rtg.sequence>):
# CHECK: rtg.sequence @seq(%{{.*}}: !rtg.set<!rtg.sequence>) {
# CHECK: }
print(m)

with Context() as ctx, Location.unknown():
circt.register_dialects(ctx)
m = Module.create()
with InsertionPoint(m.body):
seq = rtg.SequenceOp('sequence_name')
seq = rtg.SequenceOp('sequence_name',
TypeAttr.get(rtg.SequenceType.get([])))
Block.create_at_start(seq.bodyRegion, [])

test = rtg.TestOp('test_name', TypeAttr.get(rtg.DictType.get()))
Expand Down Expand Up @@ -89,12 +89,14 @@
setTy = rtg.SetType.get(indexTy)
bagTy = rtg.BagType.get(indexTy)
ireg = rtgtest.IntegerRegisterType.get()
seq = rtg.SequenceOp('seq')
seq = rtg.SequenceOp(
'seq',
TypeAttr.get(
rtg.SequenceType.get([sequenceTy, labelTy, setTy, bagTy, ireg])))
Block.create_at_start(seq.bodyRegion,
[sequenceTy, labelTy, setTy, bagTy, ireg])

# CHECK: rtg.sequence @seq
# CHECK: (%{{.*}}: !rtg.sequence, %{{.*}}: !rtg.label, %{{.*}}: !rtg.set<index>, %{{.*}}: !rtg.bag<index>, %{{.*}}: !rtgtest.ireg):
# CHECK: rtg.sequence @seq(%{{.*}}: !rtg.sequence, %{{.*}}: !rtg.label, %{{.*}}: !rtg.set<index>, %{{.*}}: !rtg.bag<index>, %{{.*}}: !rtgtest.ireg)
print(m)

with Context() as ctx, Location.unknown():
Expand Down Expand Up @@ -189,7 +191,7 @@
circt.register_dialects(ctx)
m = Module.create()
with InsertionPoint(m.body):
seq = rtg.SequenceOp('seq')
seq = rtg.SequenceOp('seq', TypeAttr.get(rtg.SequenceType.get([])))
block = Block.create_at_start(seq.bodyRegion, [])
with InsertionPoint(block):
l = rtg.label_decl("label", [])
Expand Down
15 changes: 12 additions & 3 deletions lib/Bindings/Python/RTGModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,19 @@ void circt::python::populateDialectRTGSubmodule(py::module &m) {
mlir_type_subclass(m, "SequenceType", rtgTypeIsASequence)
.def_classmethod(
"get",
[](py::object cls, MlirContext ctxt) {
return cls(rtgSequenceTypeGet(ctxt));
[](py::object cls, std::vector<MlirType> &elementTypes,
MlirContext ctxt) {
return cls(rtgSequenceTypeGet(ctxt, elementTypes.size(),
elementTypes.data()));
},
py::arg("self"), py::arg("ctxt") = nullptr);
py::arg("self"), py::arg("elementTypes") = std::vector<MlirType>(),
py::arg("ctxt") = nullptr)
.def_property_readonly(
"num_elements",
[](MlirType self) { return rtgSequenceTypeGetNumElements(self); })
.def("get_element", [](MlirType self, unsigned i) {
return rtgSequenceTypeGetElement(self, i);
});

mlir_type_subclass(m, "LabelType", rtgTypeIsALabel)
.def_classmethod(
Expand Down
16 changes: 14 additions & 2 deletions lib/CAPI/Dialect/RTG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,20 @@ bool rtgTypeIsASequence(MlirType type) {
return isa<SequenceType>(unwrap(type));
}

MlirType rtgSequenceTypeGet(MlirContext ctxt) {
return wrap(SequenceType::get(unwrap(ctxt)));
MlirType rtgSequenceTypeGet(MlirContext ctxt, intptr_t numElements,
MlirType const *elementTypes) {
SmallVector<Type> types;
for (unsigned i = 0; i < numElements; ++i)
types.emplace_back(unwrap(elementTypes[i]));
return wrap(SequenceType::get(unwrap(ctxt), types));
}

unsigned rtgSequenceTypeGetNumElements(MlirType type) {
return cast<SequenceType>(unwrap(type)).getElementTypes().size();
}

MlirType rtgSequenceTypeGetElement(MlirType type, unsigned i) {
return wrap(cast<SequenceType>(unwrap(type)).getElementTypes()[i]);
}

// LabelType
Expand Down
74 changes: 73 additions & 1 deletion lib/Dialect/RTG/IR/RTGOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,77 @@ using namespace mlir;
using namespace circt;
using namespace rtg;

//===----------------------------------------------------------------------===//
// SequenceOp
//===----------------------------------------------------------------------===//

LogicalResult SequenceOp::verifyRegions() {
if (TypeRange(getSequenceType().getElementTypes()) !=
getBody()->getArgumentTypes())
return emitOpError("sequence type does not match block argument types");

return success();
}

ParseResult SequenceOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse the name as a symbol.
if (parser.parseSymbolName(
result.getOrAddProperties<SequenceOp::Properties>().sym_name))
return failure();

// Parse the function signature.
SmallVector<OpAsmParser::Argument> arguments;
if (parser.parseArgumentList(arguments, OpAsmParser::Delimiter::Paren,
/*allowType=*/true, /*allowAttrs=*/true))
return failure();

SmallVector<Type> argTypes;
SmallVector<Location> argLocs;
argTypes.reserve(arguments.size());
argLocs.reserve(arguments.size());
for (auto &arg : arguments) {
argTypes.push_back(arg.type);
argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
}
Type type = SequenceType::get(result.getContext(), argTypes);
result.getOrAddProperties<SequenceOp::Properties>().sequenceType =
TypeAttr::get(type);

auto loc = parser.getCurrentLocation();
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
return parser.emitError(loc)
<< "'" << result.name.getStringRef() << "' op ";
})))
return failure();

std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
if (parser.parseRegion(*bodyRegionRegion, arguments))
return failure();

if (bodyRegionRegion->empty()) {
bodyRegionRegion->emplaceBlock();
bodyRegionRegion->addArguments(argTypes, argLocs);
}
result.addRegion(std::move(bodyRegionRegion));

return success();
}

void SequenceOp::print(OpAsmPrinter &p) {
p << ' ';
p.printSymbolName(getSymNameAttr().getValue());
p << "(";
llvm::interleaveComma(getBody()->getArguments(), p,
[&](auto arg) { p.printRegionArgument(arg); });
p << ")";
p.printOptionalAttrDictWithKeyword(
(*this)->getAttrs(), {getSymNameAttrName(), getSequenceTypeAttrName()});
p << ' ';
p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
}

//===----------------------------------------------------------------------===//
// SequenceClosureOp
//===----------------------------------------------------------------------===//
Expand All @@ -31,7 +102,8 @@ SequenceClosureOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
<< "'" << getSequence()
<< "' does not reference a valid 'rtg.sequence' operation";

if (seq.getBodyRegion().getArgumentTypes() != getArgs().getTypes())
if (TypeRange(seq.getSequenceType().getElementTypes()) !=
getArgs().getTypes())
return emitOpError("referenced 'rtg.sequence' op's argument types must "
"match 'args' types");

Expand Down
2 changes: 1 addition & 1 deletion test/CAPI/rtg-pipelines.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ int main(int argc, char **argv) {
mlirDialectHandleRegisterDialect(mlirGetDialectHandle__rtg__(), ctx);

MlirModule moduleOp = mlirModuleCreateParse(
ctx, mlirStringRefCreateFromCString("rtg.sequence @seq {\n"
ctx, mlirStringRefCreateFromCString("rtg.sequence @seq() {\n"
"}\n"
"rtg.test @test : !rtg.dict<> {\n"
" %0 = rtg.sequence_closure @seq\n"
Expand Down
13 changes: 12 additions & 1 deletion test/CAPI/rtg.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,24 @@
#include "mlir-c/BuiltinTypes.h"

static void testSequenceType(MlirContext ctx) {
MlirType sequenceTy = rtgSequenceTypeGet(ctx);
MlirType sequenceTy = rtgSequenceTypeGet(ctx, 0, NULL);

// CHECK: is_sequence
fprintf(stderr, rtgTypeIsASequence(sequenceTy) ? "is_sequence\n"
: "isnot_sequence\n");
// CHECK: !rtg.sequence
mlirTypeDump(sequenceTy);

MlirType sequenceWithArgsTy = rtgSequenceTypeGet(ctx, 1, &sequenceTy);
// CHECK: is_sequence
fprintf(stderr, rtgTypeIsASequence(sequenceWithArgsTy) ? "is_sequence\n"
: "isnot_sequence\n");
// CHECK: 1
fprintf(stderr, "%d\n", rtgSequenceTypeGetNumElements(sequenceWithArgsTy));
// CHECK: !rtg.sequence
mlirTypeDump(rtgSequenceTypeGetElement(sequenceWithArgsTy, 0));
// CHECK: !rtg.sequence<!rtg.sequence>
mlirTypeDump(sequenceWithArgsTy);
}

static void testLabelType(MlirContext ctx) {
Expand Down
20 changes: 10 additions & 10 deletions test/Dialect/RTG/IR/basic.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
// RUN: circt-opt %s | FileCheck %s
// RUN: circt-opt %s --verify-roundtrip | FileCheck %s

// CHECK-LABEL: rtg.sequence @seq
// CHECK-SAME: attributes {rtg.some_attr} {
rtg.sequence @seq0 attributes {rtg.some_attr} {
rtg.sequence @seq0() {
%arg = arith.constant 1 : index
// CHECK: [[LBL0:%.*]] = rtg.label_decl "label_string_{0}_{1}", %{{.*}}, %{{.*}}
%0 = rtg.label_decl "label_string_{0}_{1}", %arg, %arg
Expand All @@ -16,14 +15,16 @@ rtg.sequence @seq0 attributes {rtg.some_attr} {
rtg.label external %0
}

// CHECK-LABEL: rtg.sequence @seqAttrsAndTypeElements
// CHECK-SAME: (%arg0: !rtg.sequence<!rtg.sequence<!rtg.label, !rtg.set<index>>>) attributes {rtg.some_attr} {
rtg.sequence @seqAttrsAndTypeElements(%arg0: !rtg.sequence<!rtg.sequence<!rtg.label, !rtg.set<index>>>) attributes {rtg.some_attr} {}

// CHECK-LABEL: rtg.sequence @seq1
// CHECK: ^bb0(%arg0: i32, %arg1: !rtg.sequence):
rtg.sequence @seq1 {
^bb0(%arg0: i32, %arg1: !rtg.sequence):
}
// CHECK-SAME: (%arg0: i32, %arg1: !rtg.sequence)
rtg.sequence @seq1(%arg0: i32, %arg1: !rtg.sequence) { }

// CHECK-LABEL: rtg.sequence @invocations
rtg.sequence @invocations {
rtg.sequence @invocations() {
// CHECK: [[V0:%.+]] = rtg.sequence_closure @seq0
// CHECK: [[C0:%.+]] = arith.constant 0 : i32
// CHECK: [[V1:%.+]] = rtg.sequence_closure @seq1([[C0]], [[V0]] : i32, !rtg.sequence)
Expand Down Expand Up @@ -55,8 +56,7 @@ func.func @sets(%arg0: i32, %arg1: i32) {
}

// CHECK-LABEL: @bags
rtg.sequence @bags {
^bb0(%arg0: i32, %arg1: i32, %arg2: index):
rtg.sequence @bags(%arg0: i32, %arg1: i32, %arg2: index) {
// CHECK: [[BAG:%.+]] = rtg.bag_create (%arg2 x %arg0, %arg2 x %arg1) : i32 {rtg.some_attr}
// CHECK: [[R:%.+]] = rtg.bag_select_random [[BAG]] : !rtg.bag<i32> {rtg.some_attr}
// CHECK: [[EMPTY:%.+]] = rtg.bag_create : i32
Expand Down
2 changes: 1 addition & 1 deletion test/Dialect/RTG/IR/cse.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

// CHECK-LABEL: rtg.sequence @seq
// CHECK-SAME: attributes {rtg.some_attr} {
rtg.sequence @seq0 attributes {rtg.some_attr} {
rtg.sequence @seq0() attributes {rtg.some_attr} {
// CHECK-NEXT: arith.constant
%arg = arith.constant 1 : index
// CHECK-NEXT: rtg.label_decl "label_string_{0}_{1}", %{{.*}}, %{{.*}}
Expand Down
Loading

0 comments on commit 1058927

Please sign in to comment.