Skip to content

Commit

Permalink
fix parallel/reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Dec 23, 2023
1 parent 9219966 commit 1c96345
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 21 deletions.
23 changes: 15 additions & 8 deletions mlir/extras/dialects/ext/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,9 @@ def induction_variables(self):
return self.body.arguments


parange_ = region_op(_parfor(ParallelOp), terminator=yield__)
parange_ = region_op(
_parfor(ParallelOp), terminator=lambda xs: reduce_return(xs[0]) if xs else None
)
parange = _parfor_cm(ParallelOp)


Expand Down Expand Up @@ -332,20 +334,25 @@ def while__(cond: Value, *, loc=None, ip=None):


class ReduceOp(ReduceOp):
def __init__(self, operand, *, loc=None, ip=None):
super().__init__(operand, loc=loc, ip=ip)
self.regions[0].blocks.append(operand.type, operand.type)
def __init__(self, operands, num_reductions, *, loc=None, ip=None):
super().__init__(operands, num_reductions, loc=loc, ip=ip)
for i in range(num_reductions):
self.regions[i].blocks.append(operands[i].type, operands[i].type)


def reduce_(operand, *, loc=None, ip=None):
if loc is None:
loc = get_user_code_loc()
return ReduceOp(operand, loc=loc, ip=ip)
def reduce_(*operands, num_reductions=1):
loc = get_user_code_loc()
return ReduceOp(operands, num_reductions, loc=loc)


reduce = region_op(reduce_, terminator=lambda xs: reduce_return(*xs))


@region_adder(terminator=lambda xs: reduce_return(*xs))
def reduce2(reduce_op):
return reduce_op.regions[1]


def yield_(*args):
if len(args) == 1 and isinstance(args[0], (list, OpResultList)):
args = list(args[0])
Expand Down
21 changes: 8 additions & 13 deletions tests/test_scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
while__,
while___,
placeholder_opaque_t,
reduce2,
)
from mlir.extras.dialects.ext.tensor import empty, Tensor
from mlir.dialects.memref import alloca_scope_return
Expand Down Expand Up @@ -2429,6 +2430,7 @@ def forfoo(i, j, shared_outs):
filecheck(correct, ctx.module)


@pytest.mark.xfail
def test_parange_no_inits(ctx: MLIRContext):
ten = empty((10, 10), T.i32())

Expand Down Expand Up @@ -2494,12 +2496,12 @@ def test_forall_insert_slice_no_region_with_for(ctx: MLIRContext):
filecheck(correct, ctx.module)


@pytest.mark.xfail
def test_parange_no_inits_with_for(ctx: MLIRContext):
ten = empty((10, 10), T.i32())

for i, j in parange([1, 1], [2, 2], [3, 3], inits=[]):
one = constant(1.0)
yield_()

ctx.module.operation.verify()
correct = dedent(
Expand Down Expand Up @@ -2535,8 +2537,6 @@ def res(lhs: Tensor, rhs: Tensor):
assert isinstance(rhs, Tensor)
return lhs + rhs

yield_()

ctx.module.operation.verify()
correct = dedent(
"""\
Expand All @@ -2551,12 +2551,11 @@ def res(lhs: Tensor, rhs: Tensor):
%1 = scf.parallel (%arg0, %arg1) = (%c1, %c1_0) to (%c2, %c2_1) step (%c3, %c3_2) init (%0) -> tensor<10x10xi32> {
%cst = arith.constant 1.000000e+00 : f32
%2 = tensor.empty() : tensor<10x10xi32>
scf.reduce(%2) : tensor<10x10xi32> {
scf.reduce(%2 : tensor<10x10xi32>) {
^bb0(%arg2: tensor<10x10xi32>, %arg3: tensor<10x10xi32>):
%3 = arith.addi %arg2, %arg3 : tensor<10x10xi32>
scf.reduce.return %3 : tensor<10x10xi32>
}
scf.yield
}
}
"""
Expand All @@ -2569,16 +2568,14 @@ def test_parange_inits_with_for_with_two_reduce(ctx: MLIRContext):

for i, j in parange([1, 1], [2, 2], [3, 3], inits=[one, one]):

@reduce(i)
@reduce(i, j, num_reductions=2)
def res1(lhs: T.index(), rhs: T.index()):
return lhs + rhs

@reduce(j)
@reduce2(res1)
def res1(lhs: T.index(), rhs: T.index()):
return lhs + rhs

yield_()

ctx.module.operation.verify()
correct = dedent(
"""\
Expand All @@ -2591,17 +2588,15 @@ def res1(lhs: T.index(), rhs: T.index()):
%c3 = arith.constant 3 : index
%c3_3 = arith.constant 3 : index
%0:2 = scf.parallel (%arg0, %arg1) = (%c1_0, %c1_1) to (%c2, %c2_2) step (%c3, %c3_3) init (%c1, %c1) -> (index, index) {
scf.reduce(%arg0) : index {
scf.reduce(%arg0, %arg1 : index, index) {
^bb0(%arg2: index, %arg3: index):
%1 = arith.addi %arg2, %arg3 : index
scf.reduce.return %1 : index
}
scf.reduce(%arg1) : index {
}, {
^bb0(%arg2: index, %arg3: index):
%1 = arith.addi %arg2, %arg3 : index
scf.reduce.return %1 : index
}
scf.yield
}
}
"""
Expand Down

0 comments on commit 1c96345

Please sign in to comment.