Skip to content

Commit

Permalink
Fix transform and scf reduce (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Dec 23, 2023
1 parent f9b4c19 commit a6ac813
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
11 changes: 4 additions & 7 deletions mlir/extras/dialects/ext/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,13 +349,10 @@ def reduce_(*operands, num_reductions=1):


@region_adder(terminator=lambda xs: reduce_return(*xs))
def reduce2(reduce_op):
return reduce_op.regions[1]


@region_adder(terminator=lambda xs: reduce_return(*xs))
def reduce3(reduce_op):
return reduce_op.regions[2]
def another_reduce(reduce_op):
for r in reduce_op.regions:
if len(r.blocks[0].operations) == 0:
return r


def yield_(*args):
Expand Down
9 changes: 4 additions & 5 deletions tests/test_scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
while__,
while___,
placeholder_opaque_t,
reduce2,
reduce3,
another_reduce,
)
from mlir.extras.dialects.ext.tensor import empty, Tensor
from mlir.dialects.memref import alloca_scope_return
Expand Down Expand Up @@ -2573,7 +2572,7 @@ def test_parange_inits_with_for_with_two_reduce(ctx: MLIRContext):
def res1(lhs: T.index(), rhs: T.index()):
return lhs + rhs

@reduce2(res1)
@another_reduce(res1)
def res1(lhs: T.index(), rhs: T.index()):
return lhs + rhs

Expand Down Expand Up @@ -2614,11 +2613,11 @@ def test_parange_inits_with_for_with_three_reduce(ctx: MLIRContext):
def res1(lhs: T.index(), rhs: T.index()):
return lhs + rhs

@reduce2(res1)
@another_reduce(res1)
def res1(lhs: T.index(), rhs: T.index()):
return lhs + rhs

@reduce3(res1)
@another_reduce(res1)
def res2(lhs: T.index(), rhs: T.index()):
return lhs + rhs

Expand Down

0 comments on commit a6ac813

Please sign in to comment.