diff --git a/mlir/extras/dialects/ext/scf.py b/mlir/extras/dialects/ext/scf.py index 77f0189..3d6854e 100644 --- a/mlir/extras/dialects/ext/scf.py +++ b/mlir/extras/dialects/ext/scf.py @@ -292,7 +292,9 @@ def induction_variables(self): return self.body.arguments -parange_ = region_op(_parfor(ParallelOp), terminator=yield__) +parange_ = region_op( + _parfor(ParallelOp), terminator=lambda xs: reduce_return(xs[0]) if xs else None +) parange = _parfor_cm(ParallelOp) @@ -332,20 +334,30 @@ def while__(cond: Value, *, loc=None, ip=None): class ReduceOp(ReduceOp): - def __init__(self, operand, *, loc=None, ip=None): - super().__init__(operand, loc=loc, ip=ip) - self.regions[0].blocks.append(operand.type, operand.type) + def __init__(self, operands, num_reductions, *, loc=None, ip=None): + super().__init__(operands, num_reductions, loc=loc, ip=ip) + for i in range(num_reductions): + self.regions[i].blocks.append(operands[i].type, operands[i].type) -def reduce_(operand, *, loc=None, ip=None): - if loc is None: - loc = get_user_code_loc() - return ReduceOp(operand, loc=loc, ip=ip) +def reduce_(*operands, num_reductions=1): + loc = get_user_code_loc() + return ReduceOp(operands, num_reductions, loc=loc) reduce = region_op(reduce_, terminator=lambda xs: reduce_return(*xs)) +@region_adder(terminator=lambda xs: reduce_return(*xs)) +def reduce2(reduce_op): + return reduce_op.regions[1] + + +@region_adder(terminator=lambda xs: reduce_return(*xs)) +def reduce3(reduce_op): + return reduce_op.regions[2] + + def yield_(*args): if len(args) == 1 and isinstance(args[0], (list, OpResultList)): args = list(args[0]) diff --git a/tests/test_scf.py b/tests/test_scf.py index 8308e9d..86e1c1e 100644 --- a/tests/test_scf.py +++ b/tests/test_scf.py @@ -24,6 +24,8 @@ while__, while___, placeholder_opaque_t, + reduce2, + reduce3, ) from mlir.extras.dialects.ext.tensor import empty, Tensor from mlir.dialects.memref import alloca_scope_return @@ -2429,6 +2431,7 @@ def forfoo(i, j, shared_outs): filecheck(correct, ctx.module) +@pytest.mark.xfail def test_parange_no_inits(ctx: MLIRContext): ten = empty((10, 10), T.i32()) @@ -2494,12 +2497,12 @@ def test_forall_insert_slice_no_region_with_for(ctx: MLIRContext): filecheck(correct, ctx.module) +@pytest.mark.xfail def test_parange_no_inits_with_for(ctx: MLIRContext): ten = empty((10, 10), T.i32()) for i, j in parange([1, 1], [2, 2], [3, 3], inits=[]): one = constant(1.0) - yield_() ctx.module.operation.verify() correct = dedent( @@ -2535,8 +2538,6 @@ def res(lhs: Tensor, rhs: Tensor): assert isinstance(rhs, Tensor) return lhs + rhs - yield_() - ctx.module.operation.verify() correct = dedent( """\ @@ -2551,12 +2552,11 @@ def res(lhs: Tensor, rhs: Tensor): %1 = scf.parallel (%arg0, %arg1) = (%c1, %c1_0) to (%c2, %c2_1) step (%c3, %c3_2) init (%0) -> tensor<10x10xi32> { %cst = arith.constant 1.000000e+00 : f32 %2 = tensor.empty() : tensor<10x10xi32> - scf.reduce(%2) : tensor<10x10xi32> { + scf.reduce(%2 : tensor<10x10xi32>) { ^bb0(%arg2: tensor<10x10xi32>, %arg3: tensor<10x10xi32>): %3 = arith.addi %arg2, %arg3 : tensor<10x10xi32> scf.reduce.return %3 : tensor<10x10xi32> } - scf.yield } } """ @@ -2569,16 +2569,14 @@ def test_parange_inits_with_for_with_two_reduce(ctx: MLIRContext): for i, j in parange([1, 1], [2, 2], [3, 3], inits=[one, one]): - @reduce(i) + @reduce(i, j, num_reductions=2) def res1(lhs: T.index(), rhs: T.index()): return lhs + rhs - @reduce(j) + @reduce2(res1) def res1(lhs: T.index(), rhs: T.index()): return lhs + rhs - yield_() - ctx.module.operation.verify() correct = dedent( """\ @@ -2591,17 +2589,67 @@ def res1(lhs: T.index(), rhs: T.index()): %c3 = arith.constant 3 : index %c3_3 = arith.constant 3 : index %0:2 = scf.parallel (%arg0, %arg1) = (%c1_0, %c1_1) to (%c2, %c2_2) step (%c3, %c3_3) init (%c1, %c1) -> (index, index) { - scf.reduce(%arg0) : index { + scf.reduce(%arg0, %arg1 : index, index) { ^bb0(%arg2: index, %arg3: index): %1 = arith.addi %arg2, %arg3 : index scf.reduce.return %1 : index - } - scf.reduce(%arg1) : index { + }, { ^bb0(%arg2: index, %arg3: index): %1 = arith.addi %arg2, %arg3 : index scf.reduce.return %1 : index } - scf.yield + } + } + """ + ) + filecheck(correct, ctx.module) + + +def test_parange_inits_with_for_with_three_reduce(ctx: MLIRContext): + one = constant(1, index=True) + + for i, j, k in parange([1, 1, 1], [2, 2, 2], [3, 3, 3], inits=[one, one, one]): + + @reduce(i, j, k, num_reductions=3) + def res1(lhs: T.index(), rhs: T.index()): + return lhs + rhs + + @reduce2(res1) + def res1(lhs: T.index(), rhs: T.index()): + return lhs + rhs + + @reduce3(res1) + def res2(lhs: T.index(), rhs: T.index()): + return lhs + rhs + + ctx.module.operation.verify() + correct = dedent( + """\ + module { + %c1 = arith.constant 1 : index + %c1_0 = arith.constant 1 : index + %c1_1 = arith.constant 1 : index + %c1_2 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c2_3 = arith.constant 2 : index + %c2_4 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c3_5 = arith.constant 3 : index + %c3_6 = arith.constant 3 : index + %0:3 = scf.parallel (%arg0, %arg1, %arg2) = (%c1_0, %c1_1, %c1_2) to (%c2, %c2_3, %c2_4) step (%c3, %c3_5, %c3_6) init (%c1, %c1, %c1) -> (index, index, index) { + scf.reduce(%arg0, %arg1, %arg2 : index, index, index) { + ^bb0(%arg3: index, %arg4: index): + %1 = arith.addi %arg3, %arg4 : index + scf.reduce.return %1 : index + }, { + ^bb0(%arg3: index, %arg4: index): + %1 = arith.addi %arg3, %arg4 : index + scf.reduce.return %1 : index + }, { + ^bb0(%arg3: index, %arg4: index): + %1 = arith.addi %arg3, %arg4 : index + scf.reduce.return %1 : index + } } } """ diff --git a/tests/test_transform.py b/tests/test_transform.py index 9276200..e0b9f9e 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -31,7 +31,6 @@ tile, tile_to_scf_forall, ) -from mlir.extras.meta import region_op from mlir.extras.runtime.passes import run_pipeline, Pipeline # noinspection PyUnresolvedReferences