Skip to content

Commit

Permalink
fix paralle/reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Dec 23, 2023
1 parent 9219966 commit ef96e7b
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 22 deletions.
28 changes: 20 additions & 8 deletions mlir/extras/dialects/ext/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,9 @@ def induction_variables(self):
return self.body.arguments


parange_ = region_op(_parfor(ParallelOp), terminator=yield__)
parange_ = region_op(
_parfor(ParallelOp), terminator=lambda xs: reduce_return(xs[0]) if xs else None
)
parange = _parfor_cm(ParallelOp)


Expand Down Expand Up @@ -332,20 +334,30 @@ def while__(cond: Value, *, loc=None, ip=None):


class ReduceOp(ReduceOp):
def __init__(self, operand, *, loc=None, ip=None):
super().__init__(operand, loc=loc, ip=ip)
self.regions[0].blocks.append(operand.type, operand.type)
def __init__(self, operands, num_reductions, *, loc=None, ip=None):
super().__init__(operands, num_reductions, loc=loc, ip=ip)
for i in range(num_reductions):
self.regions[i].blocks.append(operands[i].type, operands[i].type)


def reduce_(operand, *, loc=None, ip=None):
if loc is None:
loc = get_user_code_loc()
return ReduceOp(operand, loc=loc, ip=ip)
def reduce_(*operands, num_reductions=1):
loc = get_user_code_loc()
return ReduceOp(operands, num_reductions, loc=loc)


reduce = region_op(reduce_, terminator=lambda xs: reduce_return(*xs))


@region_adder(terminator=lambda xs: reduce_return(*xs))
def reduce2(reduce_op):
return reduce_op.regions[1]


@region_adder(terminator=lambda xs: reduce_return(*xs))
def reduce3(reduce_op):
return reduce_op.regions[2]


def yield_(*args):
if len(args) == 1 and isinstance(args[0], (list, OpResultList)):
args = list(args[0])
Expand Down
74 changes: 61 additions & 13 deletions tests/test_scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
while__,
while___,
placeholder_opaque_t,
reduce2,
reduce3,
)
from mlir.extras.dialects.ext.tensor import empty, Tensor
from mlir.dialects.memref import alloca_scope_return
Expand Down Expand Up @@ -2429,6 +2431,7 @@ def forfoo(i, j, shared_outs):
filecheck(correct, ctx.module)


@pytest.mark.xfail
def test_parange_no_inits(ctx: MLIRContext):
ten = empty((10, 10), T.i32())

Expand Down Expand Up @@ -2494,12 +2497,12 @@ def test_forall_insert_slice_no_region_with_for(ctx: MLIRContext):
filecheck(correct, ctx.module)


@pytest.mark.xfail
def test_parange_no_inits_with_for(ctx: MLIRContext):
ten = empty((10, 10), T.i32())

for i, j in parange([1, 1], [2, 2], [3, 3], inits=[]):
one = constant(1.0)
yield_()

ctx.module.operation.verify()
correct = dedent(
Expand Down Expand Up @@ -2535,8 +2538,6 @@ def res(lhs: Tensor, rhs: Tensor):
assert isinstance(rhs, Tensor)
return lhs + rhs

yield_()

ctx.module.operation.verify()
correct = dedent(
"""\
Expand All @@ -2551,12 +2552,11 @@ def res(lhs: Tensor, rhs: Tensor):
%1 = scf.parallel (%arg0, %arg1) = (%c1, %c1_0) to (%c2, %c2_1) step (%c3, %c3_2) init (%0) -> tensor<10x10xi32> {
%cst = arith.constant 1.000000e+00 : f32
%2 = tensor.empty() : tensor<10x10xi32>
scf.reduce(%2) : tensor<10x10xi32> {
scf.reduce(%2 : tensor<10x10xi32>) {
^bb0(%arg2: tensor<10x10xi32>, %arg3: tensor<10x10xi32>):
%3 = arith.addi %arg2, %arg3 : tensor<10x10xi32>
scf.reduce.return %3 : tensor<10x10xi32>
}
scf.yield
}
}
"""
Expand All @@ -2569,16 +2569,14 @@ def test_parange_inits_with_for_with_two_reduce(ctx: MLIRContext):

for i, j in parange([1, 1], [2, 2], [3, 3], inits=[one, one]):

@reduce(i)
@reduce(i, j, num_reductions=2)
def res1(lhs: T.index(), rhs: T.index()):
return lhs + rhs

@reduce(j)
@reduce2(res1)
def res1(lhs: T.index(), rhs: T.index()):
return lhs + rhs

yield_()

ctx.module.operation.verify()
correct = dedent(
"""\
Expand All @@ -2591,17 +2589,67 @@ def res1(lhs: T.index(), rhs: T.index()):
%c3 = arith.constant 3 : index
%c3_3 = arith.constant 3 : index
%0:2 = scf.parallel (%arg0, %arg1) = (%c1_0, %c1_1) to (%c2, %c2_2) step (%c3, %c3_3) init (%c1, %c1) -> (index, index) {
scf.reduce(%arg0) : index {
scf.reduce(%arg0, %arg1 : index, index) {
^bb0(%arg2: index, %arg3: index):
%1 = arith.addi %arg2, %arg3 : index
scf.reduce.return %1 : index
}
scf.reduce(%arg1) : index {
}, {
^bb0(%arg2: index, %arg3: index):
%1 = arith.addi %arg2, %arg3 : index
scf.reduce.return %1 : index
}
scf.yield
}
}
"""
)
filecheck(correct, ctx.module)


def test_parange_inits_with_for_with_three_reduce(ctx: MLIRContext):
one = constant(1, index=True)

for i, j, k in parange([1, 1, 1], [2, 2, 2], [3, 3, 3], inits=[one, one, one]):

@reduce(i, j, k, num_reductions=3)
def res1(lhs: T.index(), rhs: T.index()):
return lhs + rhs

@reduce2(res1)
def res1(lhs: T.index(), rhs: T.index()):
return lhs + rhs

@reduce3(res1)
def res2(lhs: T.index(), rhs: T.index()):
return lhs + rhs

ctx.module.operation.verify()
correct = dedent(
"""\
module {
%c1 = arith.constant 1 : index
%c1_0 = arith.constant 1 : index
%c1_1 = arith.constant 1 : index
%c1_2 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c2_3 = arith.constant 2 : index
%c2_4 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c3_5 = arith.constant 3 : index
%c3_6 = arith.constant 3 : index
%0:3 = scf.parallel (%arg0, %arg1, %arg2) = (%c1_0, %c1_1, %c1_2) to (%c2, %c2_3, %c2_4) step (%c3, %c3_5, %c3_6) init (%c1, %c1, %c1) -> (index, index, index) {
scf.reduce(%arg0, %arg1, %arg2 : index, index, index) {
^bb0(%arg3: index, %arg4: index):
%1 = arith.addi %arg3, %arg4 : index
scf.reduce.return %1 : index
}, {
^bb0(%arg3: index, %arg4: index):
%1 = arith.addi %arg3, %arg4 : index
scf.reduce.return %1 : index
}, {
^bb0(%arg3: index, %arg4: index):
%1 = arith.addi %arg3, %arg4 : index
scf.reduce.return %1 : index
}
}
}
"""
Expand Down
1 change: 0 additions & 1 deletion tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
tile,
tile_to_scf_forall,
)
from mlir.extras.meta import region_op
from mlir.extras.runtime.passes import run_pipeline, Pipeline

# noinspection PyUnresolvedReferences
Expand Down

0 comments on commit ef96e7b

Please sign in to comment.