diff --git a/mlir/extras/runtime/passes.py b/mlir/extras/runtime/passes.py index 9d4e69e..e1f8d14 100644 --- a/mlir/extras/runtime/passes.py +++ b/mlir/extras/runtime/passes.py @@ -4,7 +4,7 @@ import tempfile from contextlib import ExitStack from io import StringIO -from typing import Optional, List +from typing import Optional, List, Union from ...ir import StringAttr, Module from ...passmanager import PassManager @@ -26,7 +26,7 @@ def get_module_name_for_debug_dump(module): def run_pipeline( module, - pipeline: str, + pipeline: Union[str, "Pipeline"], description: Optional[str] = None, enable_ir_printing=False, print_pipeline=False, @@ -91,20 +91,19 @@ def __init__(self, pipeline=None, wrapper=None): pipeline = [] self._pipeline = pipeline - def Func(self, p: "Pipeline"): - assert isinstance(p, Pipeline) - self._pipeline.append(f"func.func({p.materialize(module=False)})") + def Context(self, context, p: "Pipeline"): + self._pipeline.append(f"{context}({p.materialize(module=False)})") return self + def Func(self, p: "Pipeline"): + return self.Context("func.func", p) + def Spirv(self, p: "Pipeline"): - assert isinstance(p, Pipeline) - self._pipeline.append(f"spirv.module({p.materialize(module=False)})") - return self + return self.Context("spirv.module", p) def Gpu(self, p: "Pipeline"): assert isinstance(p, Pipeline) - self._pipeline.append(f"gpu.module({p.materialize(module=False)})") - return self + return self.Context("gpu.module", p) def materialize(self, module=True): pipeline_str = ",".join(self._pipeline) @@ -116,6 +115,13 @@ def materialize(self, module=True): def __str__(self): return self.materialize() + def __iadd__(self, other: "Pipeline"): + self._pipeline.extend(other._pipeline) + return self + + def __add__(self, other: "Pipeline"): + return Pipeline(self._pipeline + other._pipeline) + def add_pass(self, pass_name, **kwargs): kwargs = { k.replace("_", "-"): int(v) if isinstance(v, bool) else v @@ -135,16 +141,26 @@ def lower_to_llvm_(self): def bufferize(self): return ( - self.Func(scf_bufferize().empty_tensor_to_alloc_tensor().linalg_bufferize()) + self.Func( + Pipeline() + .scf_bufferize() + .empty_tensor_to_alloc_tensor() + .linalg_bufferize() + ) .func_bufferize() .arith_bufferize() - .Func(tensor_bufferize().finalizing_bufferize().buffer_deallocation()) + .Func( + Pipeline() + .tensor_bufferize() + .finalizing_bufferize() + .buffer_deallocation() + ) ) def lower_to_llvm(self): return ( self.cse() - .Func(lower_affine().arith_expand().convert_math_to_llvm()) + .Func(Pipeline().lower_affine().arith_expand().convert_math_to_llvm()) .convert_math_to_libm() .expand_strided_metadata() .finalize_memref_to_llvm() @@ -152,7 +168,7 @@ def lower_to_llvm(self): .convert_cf_to_llvm() .cse() .lower_affine() - .Func(convert_arith_to_llvm()) + .Func(Pipeline().convert_arith_to_llvm()) .convert_func_to_llvm() .canonicalize() .convert_openmp_to_llvm() @@ -161,7 +177,7 @@ def lower_to_llvm(self): ) def lower_to_openmp(self): - return self.convert_scf_to_openmp().Func(lower_affine()) + return self.convert_scf_to_openmp().Func(Pipeline().lower_affine()) def sparse_compiler( self, @@ -198,10 +214,10 @@ def lower_to_vulkan(self, index_bitwidth=None): self.gpu_kernel_outlining() .fold_memref_alias_ops() .convert_gpu_to_spirv() - .Spirv(spirv_lower_abi_attrs().spirv_update_vce()) + .Spirv(Pipeline().spirv_lower_abi_attrs().spirv_update_vce()) .convert_gpu_launch_to_vulkan_launch() .finalize_memref_to_llvm() - .Func(llvm_request_c_wrappers()) + .Func(Pipeline().llvm_request_c_wrappers()) .convert_func_to_llvm(index_bitwidth=index_bitwidth) .reconcile_unrealized_casts() .launch_func_to_vulkan() @@ -3525,196 +3541,3 @@ def view_op_graph( print_result_types=print_result_types, ) return self - - -affine_data_copy_generate = Pipeline().affine_data_copy_generate -affine_expand_index_ops = Pipeline().affine_expand_index_ops -affine_loop_coalescing = Pipeline().affine_loop_coalescing -affine_loop_fusion = Pipeline().affine_loop_fusion -affine_loop_invariant_code_motion = Pipeline().affine_loop_invariant_code_motion -affine_loop_normalize = Pipeline().affine_loop_normalize -affine_loop_tile = Pipeline().affine_loop_tile -affine_loop_unroll = Pipeline().affine_loop_unroll -affine_loop_unroll_jam = Pipeline().affine_loop_unroll_jam -affine_parallelize = Pipeline().affine_parallelize -affine_pipeline_data_transfer = Pipeline().affine_pipeline_data_transfer -affine_scalrep = Pipeline().affine_scalrep -affine_simplify_structures = Pipeline().affine_simplify_structures -affine_super_vectorize = Pipeline().affine_super_vectorize -allocate_arm_sme_tiles = Pipeline().allocate_arm_sme_tiles -amdgpu_emulate_atomics = Pipeline().amdgpu_emulate_atomics -arith_bufferize = Pipeline().arith_bufferize -arith_emulate_unsupported_floats = Pipeline().arith_emulate_unsupported_floats -arith_emulate_wide_int = Pipeline().arith_emulate_wide_int -arith_expand = Pipeline().arith_expand -arith_int_narrowing = Pipeline().arith_int_narrowing -arith_unsigned_when_equivalent = Pipeline().arith_unsigned_when_equivalent -arm_neon_2d_to_intr = Pipeline().arm_neon_2d_to_intr -async_func_to_async_runtime = Pipeline().async_func_to_async_runtime -async_parallel_for = Pipeline().async_parallel_for -async_runtime_policy_based_ref_counting = ( - Pipeline().async_runtime_policy_based_ref_counting -) -async_runtime_ref_counting = Pipeline().async_runtime_ref_counting -async_runtime_ref_counting_opt = Pipeline().async_runtime_ref_counting_opt -async_to_async_runtime = Pipeline().async_to_async_runtime -buffer_deallocation = Pipeline().buffer_deallocation -buffer_hoisting = Pipeline().buffer_hoisting -buffer_loop_hoisting = Pipeline().buffer_loop_hoisting -buffer_results_to_out_params = Pipeline().buffer_results_to_out_params -bufferization_bufferize = Pipeline().bufferization_bufferize -canonicalize = Pipeline().canonicalize -control_flow_sink = Pipeline().control_flow_sink -convert_amdgpu_to_rocdl = Pipeline().convert_amdgpu_to_rocdl -convert_arith_to_llvm = Pipeline().convert_arith_to_llvm -convert_arith_to_spirv = Pipeline().convert_arith_to_spirv -convert_async_to_llvm = Pipeline().convert_async_to_llvm -convert_bufferization_to_memref = Pipeline().convert_bufferization_to_memref -convert_cf_to_llvm = Pipeline().convert_cf_to_llvm -convert_cf_to_spirv = Pipeline().convert_cf_to_spirv -convert_complex_to_libm = Pipeline().convert_complex_to_libm -convert_complex_to_llvm = Pipeline().convert_complex_to_llvm -convert_complex_to_spirv = Pipeline().convert_complex_to_spirv -convert_complex_to_standard = Pipeline().convert_complex_to_standard -convert_elementwise_to_linalg = Pipeline().convert_elementwise_to_linalg -convert_func_to_llvm = Pipeline().convert_func_to_llvm -convert_func_to_spirv = Pipeline().convert_func_to_spirv -convert_gpu_launch_to_vulkan_launch = Pipeline().convert_gpu_launch_to_vulkan_launch -convert_gpu_to_nvvm = Pipeline().convert_gpu_to_nvvm -convert_gpu_to_rocdl = Pipeline().convert_gpu_to_rocdl -convert_gpu_to_spirv = Pipeline().convert_gpu_to_spirv -convert_index_to_llvm = Pipeline().convert_index_to_llvm -convert_linalg_to_affine_loops = Pipeline().convert_linalg_to_affine_loops -convert_linalg_to_loops = Pipeline().convert_linalg_to_loops -convert_linalg_to_parallel_loops = Pipeline().convert_linalg_to_parallel_loops -convert_linalg_to_std = Pipeline().convert_linalg_to_std -convert_math_to_funcs = Pipeline().convert_math_to_funcs -convert_math_to_libm = Pipeline().convert_math_to_libm -convert_math_to_llvm = Pipeline().convert_math_to_llvm -convert_math_to_spirv = Pipeline().convert_math_to_spirv -convert_memref_to_spirv = Pipeline().convert_memref_to_spirv -convert_nvgpu_to_nvvm = Pipeline().convert_nvgpu_to_nvvm -convert_nvvm_to_llvm = Pipeline().convert_nvvm_to_llvm -convert_openacc_to_scf = Pipeline().convert_openacc_to_scf -convert_openmp_to_llvm = Pipeline().convert_openmp_to_llvm -convert_parallel_loops_to_gpu = Pipeline().convert_parallel_loops_to_gpu -convert_pdl_to_pdl_interp = Pipeline().convert_pdl_to_pdl_interp -convert_scf_to_cf = Pipeline().convert_scf_to_cf -convert_scf_to_openmp = Pipeline().convert_scf_to_openmp -convert_scf_to_spirv = Pipeline().convert_scf_to_spirv -convert_shape_constraints = Pipeline().convert_shape_constraints -convert_shape_to_std = Pipeline().convert_shape_to_std -convert_spirv_to_llvm = Pipeline().convert_spirv_to_llvm -convert_tensor_to_linalg = Pipeline().convert_tensor_to_linalg -convert_tensor_to_spirv = Pipeline().convert_tensor_to_spirv -convert_ub_to_llvm = Pipeline().convert_ub_to_llvm -convert_ub_to_spirv = Pipeline().convert_ub_to_spirv -convert_vector_to_arm_sme = Pipeline().convert_vector_to_arm_sme -convert_vector_to_gpu = Pipeline().convert_vector_to_gpu -convert_vector_to_llvm = Pipeline().convert_vector_to_llvm -convert_vector_to_scf = Pipeline().convert_vector_to_scf -convert_vector_to_spirv = Pipeline().convert_vector_to_spirv -cse = Pipeline().cse -decorate_spirv_composite_type_layout = Pipeline().decorate_spirv_composite_type_layout -drop_equivalent_buffer_results = Pipeline().drop_equivalent_buffer_results -duplicate_function_elimination = Pipeline().duplicate_function_elimination -eliminate_empty_tensors = Pipeline().eliminate_empty_tensors -empty_tensor_to_alloc_tensor = Pipeline().empty_tensor_to_alloc_tensor -enable_arm_streaming = Pipeline().enable_arm_streaming -ensure_debug_info_scope_on_llvm_func = Pipeline().ensure_debug_info_scope_on_llvm_func -expand_strided_metadata = Pipeline().expand_strided_metadata -finalize_memref_to_llvm = Pipeline().finalize_memref_to_llvm -finalizing_bufferize = Pipeline().finalizing_bufferize -fold_memref_alias_ops = Pipeline().fold_memref_alias_ops -fold_tensor_subset_ops = Pipeline().fold_tensor_subset_ops -func_bufferize = Pipeline().func_bufferize -generate_runtime_verification = Pipeline().generate_runtime_verification -gpu_async_region = Pipeline().gpu_async_region -gpu_kernel_outlining = Pipeline().gpu_kernel_outlining -gpu_launch_sink_index_computations = Pipeline().gpu_launch_sink_index_computations -gpu_map_parallel_loops = Pipeline().gpu_map_parallel_loops -gpu_to_llvm = Pipeline().gpu_to_llvm -inline = Pipeline().inline -int_range_optimizations = Pipeline().int_range_optimizations -launch_func_to_vulkan = Pipeline().launch_func_to_vulkan -linalg_bufferize = Pipeline().linalg_bufferize -linalg_fold_unit_extent_dims = Pipeline().linalg_fold_unit_extent_dims -linalg_fuse_elementwise_ops = Pipeline().linalg_fuse_elementwise_ops -linalg_generalize_named_ops = Pipeline().linalg_generalize_named_ops -linalg_inline_scalar_operands = Pipeline().linalg_inline_scalar_operands -linalg_named_op_conversion = Pipeline().linalg_named_op_conversion -llvm_legalize_for_export = Pipeline().llvm_legalize_for_export -llvm_optimize_for_nvvm_target = Pipeline().llvm_optimize_for_nvvm_target -llvm_request_c_wrappers = Pipeline().llvm_request_c_wrappers -llvm_type_consistency = Pipeline().llvm_type_consistency -loop_invariant_code_motion = Pipeline().loop_invariant_code_motion -lower_affine = Pipeline().lower_affine -lower_host_to_llvm = Pipeline().lower_host_to_llvm -lower_vector_mask = Pipeline().lower_vector_mask -map_memref_spirv_storage_class = Pipeline().map_memref_spirv_storage_class -math_uplift_to_fma = Pipeline().math_uplift_to_fma -mem2reg = Pipeline().mem2reg -memref_emulate_wide_int = Pipeline().memref_emulate_wide_int -memref_expand = Pipeline().memref_expand -normalize_memrefs = Pipeline().normalize_memrefs -nvgpu_optimize_shared_memory = Pipeline().nvgpu_optimize_shared_memory -one_shot_bufferize = Pipeline().one_shot_bufferize -opt_reduction_pass = Pipeline().opt_reduction_pass -outline_shape_computation = Pipeline().outline_shape_computation -post_sparsification_rewrite = Pipeline().post_sparsification_rewrite -pre_sparsification_rewrite = Pipeline().pre_sparsification_rewrite -print_ir = Pipeline().print_ir -print_op_stats = Pipeline().print_op_stats -promote_buffers_to_stack = Pipeline().promote_buffers_to_stack -reconcile_unrealized_casts = Pipeline().reconcile_unrealized_casts -reduction_tree = Pipeline().reduction_tree -remove_shape_constraints = Pipeline().remove_shape_constraints -resolve_ranked_shaped_type_result_dims = ( - Pipeline().resolve_ranked_shaped_type_result_dims -) -resolve_shaped_type_result_dims = Pipeline().resolve_shaped_type_result_dims -sccp = Pipeline().sccp -scf_bufferize = Pipeline().scf_bufferize -scf_for_loop_canonicalization = Pipeline().scf_for_loop_canonicalization -scf_for_loop_peeling = Pipeline().scf_for_loop_peeling -scf_for_loop_range_folding = Pipeline().scf_for_loop_range_folding -scf_for_loop_specialization = Pipeline().scf_for_loop_specialization -scf_for_to_while = Pipeline().scf_for_to_while -scf_parallel_loop_fusion = Pipeline().scf_parallel_loop_fusion -scf_parallel_loop_specialization = Pipeline().scf_parallel_loop_specialization -scf_parallel_loop_tiling = Pipeline().scf_parallel_loop_tiling -shape_bufferize = Pipeline().shape_bufferize -shape_to_shape_lowering = Pipeline().shape_to_shape_lowering -snapshot_op_locations = Pipeline().snapshot_op_locations -sparse_buffer_rewrite = Pipeline().sparse_buffer_rewrite -sparse_gpu_codegen = Pipeline().sparse_gpu_codegen -sparse_storage_specifier_to_llvm = Pipeline().sparse_storage_specifier_to_llvm -sparse_tensor_codegen = Pipeline().sparse_tensor_codegen -sparse_tensor_conversion = Pipeline().sparse_tensor_conversion -sparse_vectorization = Pipeline().sparse_vectorization -sparsification = Pipeline().sparsification -spirv_canonicalize_gl = Pipeline().spirv_canonicalize_gl -spirv_lower_abi_attrs = Pipeline().spirv_lower_abi_attrs -spirv_rewrite_inserts = Pipeline().spirv_rewrite_inserts -spirv_unify_aliased_resource = Pipeline().spirv_unify_aliased_resource -spirv_update_vce = Pipeline().spirv_update_vce -spirv_webgpu_prepare = Pipeline().spirv_webgpu_prepare -sroa = Pipeline().sroa -strip_debuginfo = Pipeline().strip_debuginfo -symbol_dce = Pipeline().symbol_dce -symbol_privatize = Pipeline().symbol_privatize -tensor_bufferize = Pipeline().tensor_bufferize -test_scf_parallel_loop_collapsing = Pipeline().test_scf_parallel_loop_collapsing -topological_sort = Pipeline().topological_sort -tosa_infer_shapes = Pipeline().tosa_infer_shapes -tosa_layerwise_constant_fold = Pipeline().tosa_layerwise_constant_fold -tosa_make_broadcastable = Pipeline().tosa_make_broadcastable -tosa_optional_decompositions = Pipeline().tosa_optional_decompositions -tosa_to_arith = Pipeline().tosa_to_arith -tosa_to_scf = Pipeline().tosa_to_scf -tosa_to_tensor = Pipeline().tosa_to_tensor -tosa_validate = Pipeline().tosa_validate -transform_dialect_check_uses = Pipeline().transform_dialect_check_uses -transform_infer_effects = Pipeline().transform_infer_effects -vector_bufferize = Pipeline().vector_bufferize -view_op_graph = Pipeline().view_op_graph diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 1c3be2e..04e8709 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,10 +1,11 @@ -from mlir.extras.runtime.passes import cse, lower_affine, convert_arith_to_llvm +from mlir.extras.runtime.passes import Pipeline as pipe def test_basic(): p = ( - cse() - .Func(lower_affine().arith_expand().convert_math_to_llvm()) + pipe() + .cse() + .Func(pipe().lower_affine().arith_expand().convert_math_to_llvm()) .convert_math_to_libm() .expand_strided_metadata() .finalize_memref_to_llvm() @@ -12,7 +13,7 @@ def test_basic(): .convert_cf_to_llvm() .cse() .lower_affine() - .Func(convert_arith_to_llvm()) + .Func(pipe().convert_arith_to_llvm()) .convert_func_to_llvm() .canonicalize() .convert_openmp_to_llvm() @@ -23,3 +24,41 @@ def test_basic(): p.materialize() == "builtin.module(cse,func.func(lower-affine,arith-expand,convert-math-to-llvm),convert-math-to-libm,expand-strided-metadata,finalize-memref-to-llvm,convert-scf-to-cf,convert-cf-to-llvm,cse,lower-affine,func.func(convert-arith-to-llvm),convert-func-to-llvm,canonicalize,convert-openmp-to-llvm,cse,reconcile-unrealized-casts)" ) + + p1 = ( + pipe() + .cse() + .Func(pipe().lower_affine().arith_expand().convert_math_to_llvm()) + .convert_math_to_libm() + .expand_strided_metadata() + .finalize_memref_to_llvm() + .convert_scf_to_cf() + .convert_cf_to_llvm() + ) + + p2 = ( + pipe() + .cse() + .lower_affine() + .Func(pipe().convert_arith_to_llvm()) + .convert_func_to_llvm() + .canonicalize() + .convert_openmp_to_llvm() + .cse() + .reconcile_unrealized_casts() + ) + + assert ( + p1 + p2 + ).materialize() == "builtin.module(cse,func.func(lower-affine,arith-expand,convert-math-to-llvm),convert-math-to-libm,expand-strided-metadata,finalize-memref-to-llvm,convert-scf-to-cf,convert-cf-to-llvm,cse,lower-affine,func.func(convert-arith-to-llvm),convert-func-to-llvm,canonicalize,convert-openmp-to-llvm,cse,reconcile-unrealized-casts)" + + +def test_context(): + p = pipe().Context( + "aie.device", + pipe().add_pass("aie-localize-locks").add_pass("aie-normalize-address-spaces"), + ) + assert ( + str(p) + == "builtin.module(aie.device(aie-localize-locks,aie-normalize-address-spaces))" + )