Skip to content

Commit

Permalink
[RTG] Add bag union operation (#7917)
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart authored Dec 9, 2024
1 parent 2aecdf7 commit daf1bda
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 3 deletions.
17 changes: 17 additions & 0 deletions include/circt/Dialect/RTG/IR/RTGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,23 @@ def BagDifferenceOp : RTGOp<"bag_difference", [
}];
}

def BagUnionOp : RTGOp<"bag_union", [
Pure, SameOperandsAndResultType, Commutative
]> {
let summary = "computes the union of bags";
let description = [{
Computes the union of the given bags. The list of sets must contain at
least one element.
}];

let arguments = (ins Variadic<BagType>:$bags);
let results = (outs BagType:$result);

let assemblyFormat = [{
$bags `:` qualified(type($result)) attr-dict
}];
}

//===- Test Specification Operations --------------------------------------===//

def TestOp : RTGOp<"test", [
Expand Down
3 changes: 2 additions & 1 deletion include/circt/Dialect/RTG/IR/RTGVisitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class RTGOpVisitor {
.template Case<SequenceOp, SequenceClosureOp, SetCreateOp,
SetSelectRandomOp, SetDifferenceOp, SetUnionOp, TestOp,
InvokeSequenceOp, BagCreateOp, BagSelectRandomOp,
BagDifferenceOp, TargetOp, YieldOp>(
BagDifferenceOp, BagUnionOp, TargetOp, YieldOp>(
[&](auto expr) -> ResultType {
return thisCast->visitOp(expr, args...);
})
Expand Down Expand Up @@ -93,6 +93,7 @@ class RTGOpVisitor {
HANDLE(BagCreateOp, Unhandled);
HANDLE(BagSelectRandomOp, Unhandled);
HANDLE(BagDifferenceOp, Unhandled);
HANDLE(BagUnionOp, Unhandled);
HANDLE(TestOp, Unhandled);
HANDLE(TargetOp, Unhandled);
HANDLE(YieldOp, Unhandled);
Expand Down
14 changes: 14 additions & 0 deletions lib/Dialect/RTG/Transforms/ElaborationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,20 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>,
return DeletionKind::Delete;
}

FailureOr<DeletionKind>
visitOp(BagUnionOp op, function_ref<void(Operation *)> addToWorklist) {
MapVector<ElaboratorValue *, uint64_t> result;
for (auto bag : op.getBags()) {
auto *val = cast<BagValue>(state.at(bag));
for (auto [el, multiple] : val->getBag())
result[el] += multiple;
}

internalizeResult<BagValue>(op.getResult(), std::move(result),
op.getType());
return DeletionKind::Delete;
}

FailureOr<DeletionKind>
dispatchOpVisitor(Operation *op,
function_ref<void(Operation *)> addToWorklist) {
Expand Down
4 changes: 3 additions & 1 deletion test/Dialect/RTG/IR/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,15 @@ rtg.sequence @bags {
// CHECK: [[BAG:%.+]] = rtg.bag_create (%arg2 x %arg0, %arg2 x %arg1) : i32 {rtg.some_attr}
// CHECK: [[R:%.+]] = rtg.bag_select_random [[BAG]] : !rtg.bag<i32> {rtg.some_attr}
// CHECK: [[EMPTY:%.+]] = rtg.bag_create : i32
// CHECK: rtg.bag_difference [[BAG]], [[EMPTY]] : !rtg.bag<i32> {rtg.some_attr}
// CHECK: [[DIFF:%.+]] = rtg.bag_difference [[BAG]], [[EMPTY]] : !rtg.bag<i32> {rtg.some_attr}
// CHECK: rtg.bag_difference [[BAG]], [[EMPTY]] inf : !rtg.bag<i32>
// CHECK: rtg.bag_union [[BAG]], [[EMPTY]], [[DIFF]] : !rtg.bag<i32>
%bag = rtg.bag_create (%arg2 x %arg0, %arg2 x %arg1) : i32 {rtg.some_attr}
%r = rtg.bag_select_random %bag : !rtg.bag<i32> {rtg.some_attr}
%empty = rtg.bag_create : i32
%diff = rtg.bag_difference %bag, %empty : !rtg.bag<i32> {rtg.some_attr}
%diff2 = rtg.bag_difference %bag, %empty inf : !rtg.bag<i32>
%union = rtg.bag_union %bag, %empty, %diff : !rtg.bag<i32>
}

// CHECK-LABEL: rtg.target @empty_target : !rtg.dict<> {
Expand Down
7 changes: 7 additions & 0 deletions test/Dialect/RTG/IR/errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,10 @@ rtg.sequence @seq {
// expected-error @below {{expected 1 or more operands, but found 0}}
rtg.set_union : !rtg.set<i32>
}

// -----

rtg.sequence @seq {
// expected-error @below {{expected 1 or more operands, but found 0}}
rtg.bag_union : !rtg.bag<i32>
}
5 changes: 4 additions & 1 deletion test/Dialect/RTG/Transform/elaboration.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,13 @@ rtg.test @bagOperations : !rtg.dict<> {
// CHECK-NEXT: [[V5:%.+]] = rtg.bag_create ([[V1]] x [[V0]]) : i32
// CHECK-NEXT: func.call @dummy4([[V0]], [[V0]], [[V4]], [[V5]]) :
%multiple = arith.constant 8 : index
%seven = arith.constant 7 : index
%one = arith.constant 1 : index
%0 = arith.constant 2 : i32
%1 = arith.constant 3 : i32
%bag = rtg.bag_create (%multiple x %0, %multiple x %1) : i32
%bag0 = rtg.bag_create (%seven x %0, %multiple x %1) : i32
%bag1 = rtg.bag_create (%one x %0) : i32
%bag = rtg.bag_union %bag0, %bag1 : !rtg.bag<i32>
%2 = rtg.bag_select_random %bag : !rtg.bag<i32> {rtg.elaboration_custom_seed = 3}
%new_bag = rtg.bag_create (%one x %2) : i32
%diff = rtg.bag_difference %bag, %new_bag : !rtg.bag<i32>
Expand Down

0 comments on commit daf1bda

Please sign in to comment.