Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix transform and scf reduce #38

Merged
merged 3 commits into from
Dec 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions mlir/extras/dialects/ext/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,9 @@ def induction_variables(self):
return self.body.arguments


# Decorator form of scf.parallel: the decorated function becomes the loop
# body.  If the body yields values they are folded through `scf.reduce.return`
# on the first value; an empty yield produces no terminator operand.
# NOTE: the previous `terminator=yield__` variant was dead diff residue and
# has been removed — this assignment is the live one.
parange_ = region_op(
    _parfor(ParallelOp), terminator=lambda xs: reduce_return(xs[0]) if xs else None
)
# Context-manager form of scf.parallel (used as `for i, j in parange(...)`).
parange = _parfor_cm(ParallelOp)


Expand Down Expand Up @@ -332,20 +334,30 @@ def while__(cond: Value, *, loc=None, ip=None):


class ReduceOp(ReduceOp):
    """Extension of the upstream scf.reduce op builder.

    Appends one block per reduction region so callers can immediately fill in
    the combiner bodies.  Each block takes two arguments of the corresponding
    operand's type: the accumulated value and the value to fold in.
    """

    def __init__(self, operands, num_reductions, *, loc=None, ip=None):
        super().__init__(operands, num_reductions, loc=loc, ip=ip)
        # Region i reduces operands[i]; its entry block receives
        # (accumulator, update), both of operands[i].type.
        for i in range(num_reductions):
            self.regions[i].blocks.append(operands[i].type, operands[i].type)


def reduce_(*operands, num_reductions=1, loc=None, ip=None):
    """Build an scf.reduce op over `operands`.

    Args:
        *operands: values produced in the parallel body, one per reduction.
        num_reductions: number of reduction regions to create (defaults to 1).
        loc: optional source location; inferred from user code when omitted.
        ip: optional insertion point, forwarded to the op builder.

    Returns:
        The constructed ReduceOp with `num_reductions` empty combiner regions.
    """
    if loc is None:
        loc = get_user_code_loc()
    # Forward loc/ip so callers that managed insertion points explicitly
    # (as the pre-rewrite signature allowed) keep working.
    return ReduceOp(operands, num_reductions, loc=loc, ip=ip)


# Decorator form of scf.reduce: the decorated function's body becomes the
# first reduction region, and its return value is terminated with
# `scf.reduce.return`.  (Intentionally shadows builtins.reduce in this module.)
reduce = region_op(reduce_, terminator=lambda xs: reduce_return(*xs))


@region_adder(terminator=lambda xs: reduce_return(*xs))
def reduce2(reduce_op):
    """Attach the decorated function's body as the second reduction region
    of an existing ReduceOp (built with num_reductions >= 2)."""
    return reduce_op.regions[1]


@region_adder(terminator=lambda xs: reduce_return(*xs))
def reduce3(reduce_op):
    """Attach the decorated function's body as the third reduction region
    of an existing ReduceOp (built with num_reductions >= 3)."""
    return reduce_op.regions[2]


def yield_(*args):
if len(args) == 1 and isinstance(args[0], (list, OpResultList)):
args = list(args[0])
Expand Down
74 changes: 6 additions & 68 deletions mlir/extras/runtime/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,79 +210,17 @@ def lower_to_vulkan(self, index_bitwidth=None):
def transform_dialect_erase_schedule(self):
    """Append the pass that erases transform-dialect schedules from the module."""
    pass_name = "test-transform-dialect-erase-schedule"
    return self.add_pass(pass_name)

def transform_interpreter(
    self,
    debug_payload_root_tag=None,
    debug_transform_root_tag=None,
    disable_expensive_checks=None,
    entry_point=None,
):
    """Append the `transform-interpreter` pass to this pipeline.

    Replaces the old `test-transform-dialect-interpreter` wrapper; the stale
    bind_*/transform_file_name plumbing from that version was dead diff
    residue and has been removed.

    Args:
        debug_payload_root_tag: tag selecting the payload root for debugging.
        debug_transform_root_tag: tag selecting the transform root for debugging.
        disable_expensive_checks: disable the interpreter's expensive checks.
        entry_point: name of the transform entry-point symbol to run.

    Returns:
        Whatever `add_pass` returns (the pipeline, for chaining).
    """
    return self.add_pass(
        "transform-interpreter",
        debug_payload_root_tag=debug_payload_root_tag,
        debug_transform_root_tag=debug_transform_root_tag,
        disable_expensive_checks=disable_expensive_checks,
        entry_point=entry_point,
    )

############################
Expand Down
74 changes: 61 additions & 13 deletions tests/test_scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
while__,
while___,
placeholder_opaque_t,
reduce2,
reduce3,
)
from mlir.extras.dialects.ext.tensor import empty, Tensor
from mlir.dialects.memref import alloca_scope_return
Expand Down Expand Up @@ -2429,6 +2431,7 @@ def forfoo(i, j, shared_outs):
filecheck(correct, ctx.module)


@pytest.mark.xfail
def test_parange_no_inits(ctx: MLIRContext):
ten = empty((10, 10), T.i32())

Expand Down Expand Up @@ -2494,12 +2497,12 @@ def test_forall_insert_slice_no_region_with_for(ctx: MLIRContext):
filecheck(correct, ctx.module)


@pytest.mark.xfail
def test_parange_no_inits_with_for(ctx: MLIRContext):
ten = empty((10, 10), T.i32())

for i, j in parange([1, 1], [2, 2], [3, 3], inits=[]):
one = constant(1.0)
yield_()

ctx.module.operation.verify()
correct = dedent(
Expand Down Expand Up @@ -2535,8 +2538,6 @@ def res(lhs: Tensor, rhs: Tensor):
assert isinstance(rhs, Tensor)
return lhs + rhs

yield_()

ctx.module.operation.verify()
correct = dedent(
"""\
Expand All @@ -2551,12 +2552,11 @@ def res(lhs: Tensor, rhs: Tensor):
%1 = scf.parallel (%arg0, %arg1) = (%c1, %c1_0) to (%c2, %c2_1) step (%c3, %c3_2) init (%0) -> tensor<10x10xi32> {
%cst = arith.constant 1.000000e+00 : f32
%2 = tensor.empty() : tensor<10x10xi32>
scf.reduce(%2) : tensor<10x10xi32> {
scf.reduce(%2 : tensor<10x10xi32>) {
^bb0(%arg2: tensor<10x10xi32>, %arg3: tensor<10x10xi32>):
%3 = arith.addi %arg2, %arg3 : tensor<10x10xi32>
scf.reduce.return %3 : tensor<10x10xi32>
}
scf.yield
}
}
"""
Expand All @@ -2569,16 +2569,14 @@ def test_parange_inits_with_for_with_two_reduce(ctx: MLIRContext):

for i, j in parange([1, 1], [2, 2], [3, 3], inits=[one, one]):

@reduce(i)
@reduce(i, j, num_reductions=2)
def res1(lhs: T.index(), rhs: T.index()):
return lhs + rhs

@reduce(j)
@reduce2(res1)
def res1(lhs: T.index(), rhs: T.index()):
return lhs + rhs

yield_()

ctx.module.operation.verify()
correct = dedent(
"""\
Expand All @@ -2591,17 +2589,67 @@ def res1(lhs: T.index(), rhs: T.index()):
%c3 = arith.constant 3 : index
%c3_3 = arith.constant 3 : index
%0:2 = scf.parallel (%arg0, %arg1) = (%c1_0, %c1_1) to (%c2, %c2_2) step (%c3, %c3_3) init (%c1, %c1) -> (index, index) {
scf.reduce(%arg0) : index {
scf.reduce(%arg0, %arg1 : index, index) {
^bb0(%arg2: index, %arg3: index):
%1 = arith.addi %arg2, %arg3 : index
scf.reduce.return %1 : index
}
scf.reduce(%arg1) : index {
}, {
^bb0(%arg2: index, %arg3: index):
%1 = arith.addi %arg2, %arg3 : index
scf.reduce.return %1 : index
}
scf.yield
}
}
"""
)
filecheck(correct, ctx.module)


def test_parange_inits_with_for_with_three_reduce(ctx: MLIRContext):
one = constant(1, index=True)

for i, j, k in parange([1, 1, 1], [2, 2, 2], [3, 3, 3], inits=[one, one, one]):

@reduce(i, j, k, num_reductions=3)
def res1(lhs: T.index(), rhs: T.index()):
return lhs + rhs

@reduce2(res1)
def res1(lhs: T.index(), rhs: T.index()):
return lhs + rhs

@reduce3(res1)
def res2(lhs: T.index(), rhs: T.index()):
return lhs + rhs

ctx.module.operation.verify()
correct = dedent(
"""\
module {
%c1 = arith.constant 1 : index
%c1_0 = arith.constant 1 : index
%c1_1 = arith.constant 1 : index
%c1_2 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c2_3 = arith.constant 2 : index
%c2_4 = arith.constant 2 : index
%c3 = arith.constant 3 : index
%c3_5 = arith.constant 3 : index
%c3_6 = arith.constant 3 : index
%0:3 = scf.parallel (%arg0, %arg1, %arg2) = (%c1_0, %c1_1, %c1_2) to (%c2, %c2_3, %c2_4) step (%c3, %c3_5, %c3_6) init (%c1, %c1, %c1) -> (index, index, index) {
scf.reduce(%arg0, %arg1, %arg2 : index, index, index) {
^bb0(%arg3: index, %arg4: index):
%1 = arith.addi %arg3, %arg4 : index
scf.reduce.return %1 : index
}, {
^bb0(%arg3: index, %arg4: index):
%1 = arith.addi %arg3, %arg4 : index
scf.reduce.return %1 : index
}, {
^bb0(%arg3: index, %arg4: index):
%1 = arith.addi %arg3, %arg4 : index
scf.reduce.return %1 : index
}
}
}
"""
Expand Down
Loading
Loading