diff --git a/include/circt/Dialect/RTG/IR/RTGOps.td b/include/circt/Dialect/RTG/IR/RTGOps.td index a270137bf59d..9f48655ec47a 100644 --- a/include/circt/Dialect/RTG/IR/RTGOps.td +++ b/include/circt/Dialect/RTG/IR/RTGOps.td @@ -224,6 +224,23 @@ def BagDifferenceOp : RTGOp<"bag_difference", [ }]; } +def BagUnionOp : RTGOp<"bag_union", [ + Pure, SameOperandsAndResultType, Commutative +]> { + let summary = "computes the union of bags"; + let description = [{ + Computes the union of the given bags. The list of sets must contain at + least one element. + }]; + + let arguments = (ins Variadic:$bags); + let results = (outs BagType:$result); + + let assemblyFormat = [{ + $bags `:` qualified(type($result)) attr-dict + }]; +} + //===- Test Specification Operations --------------------------------------===// def TestOp : RTGOp<"test", [ diff --git a/include/circt/Dialect/RTG/IR/RTGVisitors.h b/include/circt/Dialect/RTG/IR/RTGVisitors.h index 4c4babb96edc..0423bb72f699 100644 --- a/include/circt/Dialect/RTG/IR/RTGVisitors.h +++ b/include/circt/Dialect/RTG/IR/RTGVisitors.h @@ -34,7 +34,7 @@ class RTGOpVisitor { .template Case( + BagDifferenceOp, BagUnionOp, TargetOp, YieldOp>( [&](auto expr) -> ResultType { return thisCast->visitOp(expr, args...); }) @@ -93,6 +93,7 @@ class RTGOpVisitor { HANDLE(BagCreateOp, Unhandled); HANDLE(BagSelectRandomOp, Unhandled); HANDLE(BagDifferenceOp, Unhandled); + HANDLE(BagUnionOp, Unhandled); HANDLE(TestOp, Unhandled); HANDLE(TargetOp, Unhandled); HANDLE(YieldOp, Unhandled); diff --git a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp index b99ce9a1126c..a93068476b62 100644 --- a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp +++ b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp @@ -598,6 +598,20 @@ class Elaborator : public RTGOpVisitor, return DeletionKind::Delete; } + FailureOr + visitOp(BagUnionOp op, function_ref addToWorklist) { + MapVector result; + for (auto bag : op.getBags()) { + auto *val = cast(state.at(bag)); + for (auto [el, multiple] : val->getBag()) + result[el] += multiple; + } + + internalizeResult(op.getResult(), std::move(result), + op.getType()); + return DeletionKind::Delete; + } + FailureOr dispatchOpVisitor(Operation *op, function_ref addToWorklist) { diff --git a/test/Dialect/RTG/IR/basic.mlir b/test/Dialect/RTG/IR/basic.mlir index 85298992c4d8..ae49bbf9bb18 100644 --- a/test/Dialect/RTG/IR/basic.mlir +++ b/test/Dialect/RTG/IR/basic.mlir @@ -47,13 +47,15 @@ rtg.sequence @bags { // CHECK: [[BAG:%.+]] = rtg.bag_create (%arg2 x %arg0, %arg2 x %arg1) : i32 {rtg.some_attr} // CHECK: [[R:%.+]] = rtg.bag_select_random [[BAG]] : !rtg.bag {rtg.some_attr} // CHECK: [[EMPTY:%.+]] = rtg.bag_create : i32 - // CHECK: rtg.bag_difference [[BAG]], [[EMPTY]] : !rtg.bag {rtg.some_attr} + // CHECK: [[DIFF:%.+]] = rtg.bag_difference [[BAG]], [[EMPTY]] : !rtg.bag {rtg.some_attr} // CHECK: rtg.bag_difference [[BAG]], [[EMPTY]] inf : !rtg.bag + // CHECK: rtg.bag_union [[BAG]], [[EMPTY]], [[DIFF]] : !rtg.bag %bag = rtg.bag_create (%arg2 x %arg0, %arg2 x %arg1) : i32 {rtg.some_attr} %r = rtg.bag_select_random %bag : !rtg.bag {rtg.some_attr} %empty = rtg.bag_create : i32 %diff = rtg.bag_difference %bag, %empty : !rtg.bag {rtg.some_attr} %diff2 = rtg.bag_difference %bag, %empty inf : !rtg.bag + %union = rtg.bag_union %bag, %empty, %diff : !rtg.bag } // CHECK-LABEL: rtg.target @empty_target : !rtg.dict<> { diff --git a/test/Dialect/RTG/IR/errors.mlir b/test/Dialect/RTG/IR/errors.mlir index 2d5f147f234b..eb030cfbf7f8 100644 --- a/test/Dialect/RTG/IR/errors.mlir +++ b/test/Dialect/RTG/IR/errors.mlir @@ -73,3 +73,10 @@ rtg.sequence @seq { // expected-error @below {{expected 1 or more operands, but found 0}} rtg.set_union : !rtg.set } + +// ----- + +rtg.sequence @seq { + // expected-error @below {{expected 1 or more operands, but found 0}} + rtg.bag_union : !rtg.bag +} diff --git a/test/Dialect/RTG/Transform/elaboration.mlir b/test/Dialect/RTG/Transform/elaboration.mlir index d5545d17b952..5343404d02f6 100644 --- a/test/Dialect/RTG/Transform/elaboration.mlir +++ b/test/Dialect/RTG/Transform/elaboration.mlir @@ -38,10 +38,13 @@ rtg.test @bagOperations : !rtg.dict<> { // CHECK-NEXT: [[V5:%.+]] = rtg.bag_create ([[V1]] x [[V0]]) : i32 // CHECK-NEXT: func.call @dummy4([[V0]], [[V0]], [[V4]], [[V5]]) : %multiple = arith.constant 8 : index + %seven = arith.constant 7 : index %one = arith.constant 1 : index %0 = arith.constant 2 : i32 %1 = arith.constant 3 : i32 - %bag = rtg.bag_create (%multiple x %0, %multiple x %1) : i32 + %bag0 = rtg.bag_create (%seven x %0, %multiple x %1) : i32 + %bag1 = rtg.bag_create (%one x %0) : i32 + %bag = rtg.bag_union %bag0, %bag1 : !rtg.bag %2 = rtg.bag_select_random %bag : !rtg.bag {rtg.elaboration_custom_seed = 3} %new_bag = rtg.bag_create (%one x %2) : i32 %diff = rtg.bag_difference %bag, %new_bag : !rtg.bag