From 92199663d5bcc773c3802c14544a3d983d9cd086 Mon Sep 17 00:00:00 2001 From: max Date: Fri, 22 Dec 2023 19:55:40 -0600 Subject: [PATCH] fix transform --- mlir/extras/runtime/passes.py | 74 +--------- tests/test_transform.py | 268 +++++++++++++++++----------------- 2 files changed, 144 insertions(+), 198 deletions(-) diff --git a/mlir/extras/runtime/passes.py b/mlir/extras/runtime/passes.py index 06d765e..1ed76c2 100644 --- a/mlir/extras/runtime/passes.py +++ b/mlir/extras/runtime/passes.py @@ -210,79 +210,17 @@ def lower_to_vulkan(self, index_bitwidth=None): def transform_dialect_erase_schedule(self): return self.add_pass("test-transform-dialect-erase-schedule") - def transform_dialect_interpreter( + def transform_interpreter( self, - bind_first_extra_to_ops=None, - bind_first_extra_to_params=None, - bind_first_extra_to_results_of_ops=None, - bind_second_extra_to_ops=None, - bind_second_extra_to_params=None, - bind_second_extra_to_results_of_ops=None, debug_payload_root_tag=None, - debug_transform_root_tag=None, - enable_expensive_checks=None, - transform_file_name=None, - test_module_generation=None, + disable_expensive_checks=None, + entry_point=None, ): - if bind_first_extra_to_ops is not None and isinstance( - bind_first_extra_to_ops, (list, tuple) - ): - bind_first_extra_to_ops = ",".join(map(str, bind_first_extra_to_ops)) - if bind_first_extra_to_params is not None and isinstance( - bind_first_extra_to_params, (list, tuple) - ): - bind_first_extra_to_params = ",".join(map(str, bind_first_extra_to_params)) - if bind_first_extra_to_results_of_ops is not None and isinstance( - bind_first_extra_to_results_of_ops, (list, tuple) - ): - bind_first_extra_to_results_of_ops = ",".join( - map(str, bind_first_extra_to_results_of_ops) - ) - if bind_second_extra_to_ops is not None and isinstance( - bind_second_extra_to_ops, (list, tuple) - ): - bind_second_extra_to_ops = ",".join(map(str, bind_second_extra_to_ops)) - if bind_second_extra_to_params is not None and isinstance( - bind_second_extra_to_params, (list, tuple) - ): - bind_second_extra_to_params = ",".join( - map(str, bind_second_extra_to_params) - ) - if bind_second_extra_to_results_of_ops is not None and isinstance( - bind_second_extra_to_results_of_ops, (list, tuple) - ): - bind_second_extra_to_results_of_ops = ",".join( - map(str, bind_second_extra_to_results_of_ops) - ) - if debug_payload_root_tag is not None and isinstance( - debug_payload_root_tag, (list, tuple) - ): - debug_payload_root_tag = ",".join(map(str, debug_payload_root_tag)) - if debug_transform_root_tag is not None and isinstance( - debug_transform_root_tag, (list, tuple) - ): - debug_transform_root_tag = ",".join(map(str, debug_transform_root_tag)) - if transform_file_name is not None and isinstance( - transform_file_name, (list, tuple) - ): - transform_file_name = ",".join(map(str, transform_file_name)) - if test_module_generation is not None and isinstance( - test_module_generation, (list, tuple) - ): - test_module_generation = ",".join(map(str, test_module_generation)) return self.add_pass( - "test-transform-dialect-interpreter", - bind_first_extra_to_ops=bind_first_extra_to_ops, - bind_first_extra_to_params=bind_first_extra_to_params, - bind_first_extra_to_results_of_ops=bind_first_extra_to_results_of_ops, - bind_second_extra_to_ops=bind_second_extra_to_ops, - bind_second_extra_to_params=bind_second_extra_to_params, - bind_second_extra_to_results_of_ops=bind_second_extra_to_results_of_ops, + "transform-interpreter", debug_payload_root_tag=debug_payload_root_tag, - 
debug_transform_root_tag=debug_transform_root_tag, - enable_expensive_checks=enable_expensive_checks, - transform_file_name=transform_file_name, - test_module_generation=test_module_generation, + disable_expensive_checks=disable_expensive_checks, + entry_point=entry_point, ) ############################ diff --git a/tests/test_transform.py b/tests/test_transform.py index fa28692..9276200 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -10,7 +10,6 @@ apply_patterns_canonicalization, apply_cse, any_op_t, - FailurePropagationMode, ) from mlir.dialects.transform.extras import named_sequence, apply_patterns, sequence from mlir.dialects.transform.loop import loop_unroll @@ -32,6 +31,7 @@ tile, tile_to_scf_forall, ) +from mlir.extras.meta import region_op from mlir.extras.runtime.passes import run_pipeline, Pipeline # noinspection PyUnresolvedReferences @@ -53,7 +53,7 @@ def loop_unroll_op(): @module(attrs={"transform.with_named_sequence": UnitAttr.get()}) def mod(): @named_sequence("__transform_main", [any_op_t()], []) - def basic(target: any_op_t()): + def basic(target): m = structured_match(any_op_t(), target, ops=["arith.addi"]) loop = get_parent_op(pdl.op_t(), m, op_name="scf.for") loop_unroll(loop, 4) @@ -83,7 +83,7 @@ def basic(target: any_op_t()): ) filecheck(correct, ctx.module) - mod = run_pipeline(ctx.module, Pipeline().add_pass("transform-interpreter")) + mod = run_pipeline(ctx.module, Pipeline().transform_interpreter()) correct = dedent( """\ @@ -130,10 +130,12 @@ def pad_(i: T.index(), j: T.index()): pad_tensor_3_4.emit() - @sequence([], FailurePropagationMode.Propagate, []) - def basic(target): - m = match(target, ["tensor.pad"]) - tiled_linalg_op, loops = tile(m, sizes=[2, 3]) + @module(attrs={"transform.with_named_sequence": UnitAttr.get()}) + def mod(): + @named_sequence("__transform_main", [any_op_t()], []) + def basic(target): + m = match(target, ["tensor.pad"]) + tiled_linalg_op, loops = tile(m, sizes=[2, 3]) correct = dedent( """\ @@ -145,22 +147,21 @@ def basic(target): } : tensor<4x16xf32> to tensor<12x23xf32> return %padded : tensor<12x23xf32> } - transform.sequence failures(propagate) attributes {transform.target_tag = "basic"} { - ^bb0(%arg0: !pdl.operation): - %0 = transform.structured.match ops{["tensor.pad"]} in %arg0 : (!pdl.operation) -> !transform.any_op - %tiled_linalg_op, %loops:2 = transform.structured.tile_using_for %0[2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %0 = transform.structured.match ops{["tensor.pad"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %tiled_linalg_op, %loops:2 = transform.structured.tile_using_for %0[2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } } } """ ) filecheck(correct, ctx.module) - module = run_pipeline( + mod = run_pipeline( ctx.module, - Pipeline() - .add_pass("test-transform-dialect-interpreter") - .add_pass("test-transform-dialect-erase-schedule") - .canonicalize(), + Pipeline().transform_interpreter().canonicalize(), ) correct = dedent( """\ @@ -228,7 +229,7 @@ def basic(target): } """ ) - filecheck(correct, module) + filecheck(correct, mod) def test_linalg_tile(ctx: MLIRContext): @@ -243,10 +244,12 @@ def matmul( matmul.emit() - @sequence(target_tag="basic") - def basic(target): - m = match(target, ["linalg.matmul"]) - tiled_linalg_op, loops = tile(m, 
sizes=[2, 3]) + @module(attrs={"transform.with_named_sequence": UnitAttr.get()}) + def mod(): + @named_sequence("__transform_main", [any_op_t()], []) + def basic(target): + m = match(target, ["linalg.matmul"]) + tiled_linalg_op, loops = tile(m, sizes=[2, 3]) correct = dedent( """\ @@ -255,22 +258,21 @@ def basic(target): %0 = linalg.matmul {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>) outs(%arg2 : tensor<4x8xf32>) -> tensor<4x8xf32> return %0 : tensor<4x8xf32> } - transform.sequence failures(propagate) attributes {transform.target_tag = "basic"} { - ^bb0(%arg0: !pdl.operation): - %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!pdl.operation) -> !transform.any_op - %tiled_linalg_op, %loops:2 = transform.structured.tile_using_for %0[2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %tiled_linalg_op, %loops:2 = transform.structured.tile_using_for %0[2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } } } """ ) filecheck(correct, ctx.module) - module = run_pipeline( + mod = run_pipeline( ctx.module, - Pipeline() - .add_pass("test-transform-dialect-interpreter") - .add_pass("test-transform-dialect-erase-schedule") - .canonicalize(), + Pipeline().transform_interpreter().canonicalize(), ) correct = dedent( @@ -300,7 +302,7 @@ def basic(target): } """ ) - filecheck(correct, module) + filecheck(correct, mod) def test_simple_matmul_tile_foreach_thread(ctx: MLIRContext): @@ -315,10 +317,12 @@ def matmul( matmul.emit() - @sequence(target_tag="basic") - def basic(target): - m = match(target, ["linalg.matmul"]) - tiled_linalg_op, loops = tile_to_scf_forall(m, tile_sizes=[2, 3]) + @module(attrs={"transform.with_named_sequence": UnitAttr.get()}) + def mod(): + @named_sequence("__transform_main", [any_op_t()], []) + def basic(target): + m = match(target, ["linalg.matmul"]) + tiled_linalg_op, loops = tile_to_scf_forall(m, tile_sizes=[2, 3]) correct = dedent( """\ @@ -327,22 +331,21 @@ def basic(target): %0 = linalg.matmul {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>) outs(%arg2 : tensor<4x8xf32>) -> tensor<4x8xf32> return %0 : tensor<4x8xf32> } - transform.sequence failures(propagate) attributes {transform.target_tag = "basic"} { - ^bb0(%arg0: !pdl.operation): - %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!pdl.operation) -> !transform.any_op - %tiled_op, %forall_op = transform.structured.tile_using_forall %0 tile_sizes [2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %tiled_op, %forall_op = transform.structured.tile_using_forall %0 tile_sizes [2, 3] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } } } """ ) filecheck(correct, ctx.module) - module = run_pipeline( + mod = run_pipeline( ctx.module, - Pipeline() - .add_pass("test-transform-dialect-interpreter") - .add_pass("test-transform-dialect-erase-schedule") - .canonicalize(), + 
Pipeline().transform_interpreter().canonicalize(), ) correct = dedent( @@ -374,7 +377,7 @@ def basic(target): """ ) - filecheck(correct, module) + filecheck(correct, mod) def test_common_extension_sugar(ctx: MLIRContext): @@ -387,13 +390,15 @@ def select_cmp_eq_select(arg0: T.i64(), arg1: T.i64()): select_cmp_eq_select.emit() - @sequence(target_tag="basic") - def basic(target): - m = match(target, ["func.func"]) + @module(attrs={"transform.with_named_sequence": UnitAttr.get()}) + def mod(): + @named_sequence("__transform_main", [any_op_t()], []) + def basic(target): + m = match(target, ["func.func"]) - @apply_patterns(m) - def pats(): - apply_patterns_canonicalization() + @apply_patterns(m) + def pats(): + apply_patterns_canonicalization() correct = dedent( """\ @@ -403,24 +408,23 @@ def pats(): %1 = arith.select %0, %arg0, %arg1 : i64 return %1 : i64 } - transform.sequence failures(propagate) attributes {transform.target_tag = "basic"} { - ^bb0(%arg0: !pdl.operation): - %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !transform.any_op - apply_patterns to %0 { - transform.apply_patterns.canonicalization - } : !transform.any_op + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.yield + } } } """ ) filecheck(correct, ctx.module) - module = run_pipeline( + mod = run_pipeline( ctx.module, - Pipeline() - .add_pass("test-transform-dialect-interpreter") - .add_pass("test-transform-dialect-erase-schedule") - .canonicalize(), + Pipeline().transform_interpreter().canonicalize(), ) correct = dedent( @@ -433,7 +437,7 @@ def pats(): """ ) - filecheck(correct, module) + filecheck(correct, mod) def test_apply_cse(ctx: MLIRContext): @@ -450,22 +454,24 @@ def matmul( matmul.emit() - @sequence(target_tag="basic") - def basic(variant_op): - matmul = match(variant_op, ["linalg.matmul"]) + @module(attrs={"transform.with_named_sequence": UnitAttr.get()}) + def mod(): + @named_sequence("__transform_main", [any_op_t()], []) + def basic(variant_op): + matmul = match(variant_op, ["linalg.matmul"]) - forall_op, tiled_generic = tile_to_scf_forall( - matmul, tile_sizes=[2], mapping=[block_attr(MappingId.DimX)] - ) + forall_op, tiled_generic = tile_to_scf_forall( + matmul, tile_sizes=[2], mapping=[block_attr(MappingId.DimX)] + ) - top_func = match(variant_op, ["func.func"]) + top_func = match(variant_op, ["func.func"]) - @apply_patterns(top_func) - def pats(): - apply_patterns_canonicalization() + @apply_patterns(top_func) + def pats(): + apply_patterns_canonicalization() - top_func = match(variant_op, ["func.func"]) - apply_cse(top_func) + top_func = match(variant_op, ["func.func"]) + apply_cse(top_func) ctx.module.operation.verify() correct = dedent( @@ -475,28 +481,27 @@ def pats(): %0 = linalg.matmul {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<3x5xf32>, tensor<5x3xf32>) outs(%arg2 : tensor<3x3xf32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> } - transform.sequence failures(propagate) attributes {transform.target_tag = "basic"} { - ^bb0(%arg0: !pdl.operation): - %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!pdl.operation) -> !transform.any_op - %tiled_op, %forall_op = transform.structured.tile_using_forall %0 tile_sizes [2](mapping = [#gpu.block]) : 
(!transform.any_op) -> (!transform.any_op, !transform.any_op) - %1 = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !transform.any_op - apply_patterns to %1 { - transform.apply_patterns.canonicalization - } : !transform.any_op - %2 = transform.structured.match ops{["func.func"]} in %arg0 : (!pdl.operation) -> !transform.any_op - apply_cse to %2 : !transform.any_op + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %tiled_op, %forall_op = transform.structured.tile_using_forall %0 tile_sizes [2](mapping = [#gpu.block]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %1 { + transform.apply_patterns.canonicalization + } : !transform.any_op + %2 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.apply_cse to %2 : !transform.any_op + transform.yield + } } } """ ) filecheck(correct, ctx.module) - module = run_pipeline( + mod = run_pipeline( ctx.module, - Pipeline() - .transform_dialect_interpreter() - .transform_dialect_erase_schedule() - .canonicalize(), + Pipeline().transform_interpreter().canonicalize(), ) correct = dedent( @@ -520,7 +525,7 @@ def pats(): } """ ) - filecheck(correct, module) + filecheck(correct, mod) def test_two_schedules(ctx: MLIRContext): @@ -539,31 +544,33 @@ def conv_2d_nhwc_hwcf( conv_2d_nhwc_hwcf.emit() - @sequence(target_tag="tile_outer") - def tile_outer(target): - m = match(target, ["linalg.conv_2d_nchw_fchw"]) - tiled = tile_to_scf_forall( - m, - tile_sizes=[0, 1, 8, 8], - mapping=[ - block_attr(MappingId.DimX), - block_attr(MappingId.DimY), - block_attr(MappingId.DimZ), - ], - ) - - @sequence(target_tag="tile_inner") - def tile_inner(target): - m = match(target, ["linalg.conv_2d_nchw_fchw"]) - tiled = tile_to_scf_forall( - m, - tile_sizes=[0, 1, 1, 1], - mapping=[ - thread_attr(MappingId.DimX), - thread_attr(MappingId.DimY), - thread_attr(MappingId.DimZ), - ], - ) + @module(attrs={"transform.with_named_sequence": UnitAttr.get()}) + def mod(): + @named_sequence("tile_outer", [any_op_t()], []) + def tile_outer(target): + m = match(target, ["linalg.conv_2d_nchw_fchw"]) + tiled = tile_to_scf_forall( + m, + tile_sizes=[0, 1, 8, 8], + mapping=[ + block_attr(MappingId.DimX), + block_attr(MappingId.DimY), + block_attr(MappingId.DimZ), + ], + ) + + @named_sequence("tile_inner", [any_op_t()], []) + def tile_inner(target): + m = match(target, ["linalg.conv_2d_nchw_fchw"]) + tiled = tile_to_scf_forall( + m, + tile_sizes=[0, 1, 1, 1], + mapping=[ + thread_attr(MappingId.DimX), + thread_attr(MappingId.DimY), + thread_attr(MappingId.DimZ), + ], + ) correct = dedent( """\ @@ -572,27 +579,28 @@ def tile_inner(target): %0 = linalg.conv_2d_nchw_fchw ins(%arg0, %arg1 : tensor<1x1x66x66xf32>, tensor<3x1x3x3xf32>) outs(%arg2 : tensor<1x3x64x64xf32>) -> tensor<1x3x64x64xf32> return %0 : tensor<1x3x64x64xf32> } - transform.sequence failures(propagate) attributes {transform.target_tag = "tile_outer"} { - ^bb0(%arg0: !pdl.operation): - %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg0 : (!pdl.operation) -> !transform.any_op - %tiled_op, %forall_op = transform.structured.tile_using_forall %0 tile_sizes [0, 1, 8, 8](mapping = [#gpu.block, 
#gpu.block, #gpu.block]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - } - transform.sequence failures(propagate) attributes {transform.target_tag = "tile_inner"} { - ^bb0(%arg0: !pdl.operation): - %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg0 : (!pdl.operation) -> !transform.any_op - %tiled_op, %forall_op = transform.structured.tile_using_forall %0 tile_sizes [0, 1, 1, 1](mapping = [#gpu.thread, #gpu.thread, #gpu.thread]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + module attributes {transform.with_named_sequence} { + transform.named_sequence @tile_outer(%arg0: !transform.any_op) { + %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %tiled_op, %forall_op = transform.structured.tile_using_forall %0 tile_sizes [0, 1, 8, 8](mapping = [#gpu.block, #gpu.block, #gpu.block]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } + transform.named_sequence @tile_inner(%arg0: !transform.any_op) { + %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %tiled_op, %forall_op = transform.structured.tile_using_forall %0 tile_sizes [0, 1, 1, 1](mapping = [#gpu.thread, #gpu.thread, #gpu.thread]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } } } """ ) filecheck(correct, ctx.module) - module = run_pipeline( + mod = run_pipeline( ctx.module, Pipeline() - .transform_dialect_interpreter(debug_transform_root_tag="tile_outer") - .transform_dialect_interpreter(debug_transform_root_tag="tile_inner") - .transform_dialect_erase_schedule() + .transform_interpreter(entry_point="tile_outer") + .transform_interpreter(entry_point="tile_inner") .canonicalize(), ) @@ -630,4 +638,4 @@ def tile_inner(target): """ ) - filecheck(correct, module) + filecheck(correct, mod)
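
Usage sketch (illustrative, not part of the applied diff): how the renamed transform_interpreter helper added in passes.py above is driven from Python, mirroring the rewritten tests. It assumes an MLIRContext test fixture `ctx` whose module already holds the payload IR plus a transform.named_sequence schedule under the transform.with_named_sequence attribute, as emitted by the tests; only run_pipeline and Pipeline are real imports taken from mlir.extras.runtime.passes as used above.

# A minimal sketch, assuming `ctx.module` already contains both the payload IR
# and a transform.named_sequence schedule (see the tests in this patch).
from mlir.extras.runtime.passes import run_pipeline, Pipeline

# Single schedule: the interpreter runs @__transform_main over ctx.module,
# then the result is canonicalized.
lowered = run_pipeline(
    ctx.module,  # assumed: MLIRContext fixture from the tests above
    Pipeline().transform_interpreter().canonicalize(),
)

# Multiple schedules: pick each named sequence by symbol via entry_point,
# as in test_two_schedules.
lowered = run_pipeline(
    ctx.module,
    Pipeline()
    .transform_interpreter(entry_point="tile_outer")
    .transform_interpreter(entry_point="tile_inner")
    .canonicalize(),
)

Compared with the old test-transform-dialect-interpreter wrapper, the options shrink to debug_payload_root_tag, disable_expensive_checks, and entry_point, and schedules are selected by named-sequence symbol rather than by a transform.target_tag attribute.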