Skip to content

Commit 10ba6d5

Browse files
committed
[RTG] Add BagType and operations
1 parent fd08d90 commit 10ba6d5

File tree

6 files changed

+196
-5
lines changed

6 files changed

+196
-5
lines changed

include/circt/Dialect/RTG/IR/RTGOps.td

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,71 @@ def SetDifferenceOp : RTGOp<"set_difference", [
141141
}];
142142
}
143143

144+
//===- Bag Operations ------------------------------------------------------===//
145+
146+
def BagCreateOp : RTGOp<"bag_create", [Pure, SameVariadicOperandSize]> {
147+
let summary = "constructs a bag";
148+
let description = [{
149+
This operation constructs a bag with the provided values and associated
150+
multiples. This means the bag constructed in the following example contains
151+
two of each `%arg0` and `%arg0` (`{%arg0, %arg0, %arg1, %arg1}`).
152+
153+
```mlir
154+
%0 = arith.constant 2 : index
155+
%1 = rtg.bag_create (%0 x %arg0, %0 x %arg1) : i32
156+
```
157+
}];
158+
159+
let arguments = (ins Variadic<AnyType>:$elements,
160+
Variadic<Index>:$multiples);
161+
let results = (outs BagType:$bag);
162+
163+
let hasCustomAssemblyFormat = 1;
164+
let hasVerifier = 1;
165+
}
166+
167+
def BagSelectRandomOp : RTGOp<"bag_select_random", [
168+
Pure,
169+
TypesMatchWith<"output must be element type of input bag", "bag", "output",
170+
"llvm::cast<rtg::BagType>($_self).getElementType()">
171+
]> {
172+
let summary = "select a random element from the bag";
173+
let description = [{
174+
This operation returns an element from the bag selected uniformely at
175+
random. Therefore, the number of duplicates of each element can be used to
176+
bias the distribution.
177+
If the bag does not contain any elements, the behavior of this operation is
178+
undefined.
179+
}];
180+
181+
let arguments = (ins BagType:$bag);
182+
let results = (outs AnyType:$output);
183+
184+
let assemblyFormat = "$bag `:` qualified(type($bag)) attr-dict";
185+
}
186+
187+
def BagDifferenceOp : RTGOp<"bag_difference", [
188+
Pure,
189+
AllTypesMatch<["original", "diff", "output"]>
190+
]> {
191+
let summary = "computes the difference of two bags";
192+
let description = [{
193+
For each element the resulting bag will have as many fewer than the
194+
'original' bag as there are in the 'diff' bag. However, if the 'inf'
195+
attribute is attached, all elements of that kind will be removed (i.e., it
196+
is assumed the 'diff' bag has infinitely many copies of each element).
197+
}];
198+
199+
let arguments = (ins BagType:$original,
200+
BagType:$diff,
201+
UnitAttr:$inf);
202+
let results = (outs BagType:$output);
203+
204+
let assemblyFormat = [{
205+
$original `,` $diff (`inf` $inf^)? `:` qualified(type($output)) attr-dict
206+
}];
207+
}
208+
144209
//===- Test Specification Operations --------------------------------------===//
145210

146211
def TestOp : RTGOp<"test", [

include/circt/Dialect/RTG/IR/RTGTypes.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,27 @@ def SetType : RTGTypeDef<"Set"> {
4343
let assemblyFormat = "`<` $elementType `>`";
4444
}
4545

46+
def BagType : RTGTypeDef<"Bag"> {
47+
let summary = "a bag of values";
48+
let description = [{
49+
This type represents a standard bag/multiset datastructure. It does not make
50+
any assumptions about the underlying implementation.
51+
}];
52+
53+
let parameters = (ins "::mlir::Type":$elementType);
54+
55+
let mnemonic = "bag";
56+
let assemblyFormat = "`<` $elementType `>`";
57+
}
58+
4659
class SetTypeOf<Type elementType> : ContainerType<
4760
elementType, SetType.predicate,
4861
"llvm::cast<rtg::SetType>($_self).getElementType()", "set">;
4962

63+
class BagTypeOf<Type elementType> : ContainerType<
64+
elementType, BagType.predicate,
65+
"llvm::cast<rtg::BagType>($_self).getElementType()", "bag">;
66+
5067
def DictType : RTGTypeDef<"Dict"> {
5168
let summary = "a dictionary";
5269
let description = [{

include/circt/Dialect/RTG/IR/RTGVisitors.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,12 @@ class RTGOpVisitor {
3131
auto *thisCast = static_cast<ConcreteType *>(this);
3232
return TypeSwitch<Operation *, ResultType>(op)
3333
.template Case<SequenceOp, SequenceClosureOp, SetCreateOp,
34-
SetSelectRandomOp, SetDifferenceOp, InvokeSequenceOp,
35-
TestOp, TargetOp, YieldOp>([&](auto expr) -> ResultType {
36-
return thisCast->visitOp(expr, args...);
37-
})
34+
SetSelectRandomOp, SetDifferenceOp, TestOp,
35+
InvokeSequenceOp, BagCreateOp, BagSelectRandomOp,
36+
BagDifferenceOp, TargetOp, YieldOp>(
37+
[&](auto expr) -> ResultType {
38+
return thisCast->visitOp(expr, args...);
39+
})
3840
.template Case<ContextResourceOpInterface>(
3941
[&](auto expr) -> ResultType {
4042
return thisCast->visitContextResourceOp(expr, args...);
@@ -79,6 +81,9 @@ class RTGOpVisitor {
7981
HANDLE(SetCreateOp, Unhandled);
8082
HANDLE(SetSelectRandomOp, Unhandled);
8183
HANDLE(SetDifferenceOp, Unhandled);
84+
HANDLE(BagCreateOp, Unhandled);
85+
HANDLE(BagSelectRandomOp, Unhandled);
86+
HANDLE(BagDifferenceOp, Unhandled);
8287
HANDLE(TestOp, Unhandled);
8388
HANDLE(TargetOp, Unhandled);
8489
HANDLE(YieldOp, Unhandled);
@@ -93,7 +98,7 @@ class RTGTypeVisitor {
9398
ResultType dispatchTypeVisitor(Type type, ExtraArgs... args) {
9499
auto *thisCast = static_cast<ConcreteType *>(this);
95100
return TypeSwitch<Type, ResultType>(type)
96-
.template Case<SequenceType, SetType, DictType>(
101+
.template Case<SequenceType, SetType, BagType, DictType>(
97102
[&](auto expr) -> ResultType {
98103
return thisCast->visitType(expr, args...);
99104
})
@@ -138,6 +143,7 @@ class RTGTypeVisitor {
138143

139144
HANDLE(SequenceType, Unhandled);
140145
HANDLE(SetType, Unhandled);
146+
HANDLE(BagType, Unhandled);
141147
HANDLE(DictType, Unhandled);
142148
#undef HANDLE
143149
};

lib/Dialect/RTG/IR/RTGOps.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,78 @@ LogicalResult SetCreateOp::verify() {
7878
return success();
7979
}
8080

81+
//===----------------------------------------------------------------------===//
82+
// BagCreateOp
83+
//===----------------------------------------------------------------------===//
84+
85+
ParseResult BagCreateOp::parse(OpAsmParser &parser, OperationState &result) {
86+
llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> elementOperands,
87+
multipleOperands;
88+
Type elemType;
89+
90+
if (!parser.parseOptionalLParen()) {
91+
while (true) {
92+
OpAsmParser::UnresolvedOperand elementOperand, multipleOperand;
93+
if (parser.parseOperand(multipleOperand) || parser.parseKeyword("x") ||
94+
parser.parseOperand(elementOperand))
95+
return failure();
96+
97+
elementOperands.push_back(elementOperand);
98+
multipleOperands.push_back(multipleOperand);
99+
100+
if (parser.parseOptionalComma()) {
101+
if (parser.parseRParen())
102+
return failure();
103+
break;
104+
}
105+
}
106+
}
107+
108+
if (parser.parseColon() || parser.parseType(elemType) ||
109+
parser.parseOptionalAttrDict(result.attributes))
110+
return failure();
111+
112+
result.addTypes({BagType::get(result.getContext(), elemType)});
113+
114+
for (auto operand : elementOperands)
115+
if (parser.resolveOperand(operand, elemType, result.operands))
116+
return failure();
117+
118+
for (auto operand : multipleOperands)
119+
if (parser.resolveOperand(operand, IndexType::get(result.getContext()),
120+
result.operands))
121+
return failure();
122+
123+
return success();
124+
}
125+
126+
void BagCreateOp::print(OpAsmPrinter &p) {
127+
p << " ";
128+
if (!getElements().empty())
129+
p << "(";
130+
llvm::interleaveComma(llvm::zip(getElements(), getMultiples()), p,
131+
[&](auto elAndMultiple) {
132+
auto [el, multiple] = elAndMultiple;
133+
p << multiple << " x " << el;
134+
});
135+
if (!getElements().empty())
136+
p << ")";
137+
138+
p << " : " << getBag().getType().getElementType();
139+
p.printOptionalAttrDict((*this)->getAttrs());
140+
}
141+
142+
LogicalResult BagCreateOp::verify() {
143+
if (!llvm::all_equal(getElements().getTypes()))
144+
return emitOpError() << "types of all elements must match";
145+
146+
if (getElements().size() > 0)
147+
if (getElements()[0].getType() != getBag().getType().getElementType())
148+
return emitOpError() << "operand types must match bag element type";
149+
150+
return success();
151+
}
152+
81153
//===----------------------------------------------------------------------===//
82154
// TestOp
83155
//===----------------------------------------------------------------------===//

test/Dialect/RTG/IR/basic.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,21 @@ func.func @sets(%arg0: i32, %arg1: i32) {
3939
return
4040
}
4141

42+
// CHECK-LABEL: @bags
43+
rtg.sequence @bags {
44+
^bb0(%arg0: i32, %arg1: i32, %arg2: index):
45+
// CHECK: [[BAG:%.+]] = rtg.bag_create (%arg2 x %arg0, %arg2 x %arg1) : i32 {rtg.some_attr}
46+
// CHECK: [[R:%.+]] = rtg.bag_select_random [[BAG]] : !rtg.bag<i32> {rtg.some_attr}
47+
// CHECK: [[EMPTY:%.+]] = rtg.bag_create : i32
48+
// CHECK: rtg.bag_difference [[BAG]], [[EMPTY]] : !rtg.bag<i32> {rtg.some_attr}
49+
// CHECK: rtg.bag_difference [[BAG]], [[EMPTY]] inf : !rtg.bag<i32>
50+
%bag = rtg.bag_create (%arg2 x %arg0, %arg2 x %arg1) : i32 {rtg.some_attr}
51+
%r = rtg.bag_select_random %bag : !rtg.bag<i32> {rtg.some_attr}
52+
%empty = rtg.bag_create : i32
53+
%diff = rtg.bag_difference %bag, %empty : !rtg.bag<i32> {rtg.some_attr}
54+
%diff2 = rtg.bag_difference %bag, %empty inf : !rtg.bag<i32>
55+
}
56+
4257
// CHECK-LABEL: rtg.target @empty_target : !rtg.dict<> {
4358
// CHECK-NOT: rtg.yield
4459
rtg.target @empty_target : !rtg.dict<> {

test/Dialect/RTG/IR/errors.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,19 @@ rtg.test @test : !rtg.dict<b: i32, a: i32> {
5050
rtg.test @test : !rtg.dict<"": i32> {
5151
^bb0(%arg0: i32):
5252
}
53+
54+
// -----
55+
56+
rtg.sequence @seq {
57+
^bb0(%arg0: i32, %arg1: i64, %arg2: index):
58+
// expected-error @below {{types of all elements must match}}
59+
"rtg.bag_create"(%arg0, %arg1, %arg2, %arg2){} : (i32, i64, index, index) -> !rtg.bag<i32>
60+
}
61+
62+
// -----
63+
64+
rtg.sequence @seq {
65+
^bb0(%arg0: i64, %arg1: i64, %arg2: index):
66+
// expected-error @below {{operand types must match bag element type}}
67+
"rtg.bag_create"(%arg0, %arg1, %arg2, %arg2){} : (i64, i64, index, index) -> !rtg.bag<i32>
68+
}

0 commit comments

Comments
 (0)