From 5d0faef6b140fe6a8574b3fba7a8d65c5a37d20c Mon Sep 17 00:00:00 2001 From: makslevental Date: Mon, 15 Jan 2024 13:58:39 -0600 Subject: [PATCH] Pipeline improvements --- mlir/extras/runtime/passes.py | 8 ++++++-- tests/test_pipeline.py | 25 +++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/mlir/extras/runtime/passes.py b/mlir/extras/runtime/passes.py index 9d4e69e..e8ddbcd 100644 --- a/mlir/extras/runtime/passes.py +++ b/mlir/extras/runtime/passes.py @@ -4,7 +4,7 @@ import tempfile from contextlib import ExitStack from io import StringIO -from typing import Optional, List +from typing import Optional, List, Union from ...ir import StringAttr, Module from ...passmanager import PassManager @@ -26,7 +26,7 @@ def get_module_name_for_debug_dump(module): def run_pipeline( module, - pipeline: str, + pipeline: Union[str, "Pipeline"], description: Optional[str] = None, enable_ir_printing=False, print_pipeline=False, @@ -116,6 +116,10 @@ def materialize(self, module=True): def __str__(self): return self.materialize() + def __add__(self, other: "Pipeline"): + self._pipeline.extend(other._pipeline) + return self + def add_pass(self, pass_name, **kwargs): kwargs = { k.replace("_", "-"): int(v) if isinstance(v, bool) else v diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 1c3be2e..58d2857 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -23,3 +23,28 @@ def test_basic(): p.materialize() == "builtin.module(cse,func.func(lower-affine,arith-expand,convert-math-to-llvm),convert-math-to-libm,expand-strided-metadata,finalize-memref-to-llvm,convert-scf-to-cf,convert-cf-to-llvm,cse,lower-affine,func.func(convert-arith-to-llvm),convert-func-to-llvm,canonicalize,convert-openmp-to-llvm,cse,reconcile-unrealized-casts)" ) + + p1 = ( + cse() + .Func(lower_affine().arith_expand().convert_math_to_llvm()) + .convert_math_to_libm() + .expand_strided_metadata() + .finalize_memref_to_llvm() + .convert_scf_to_cf() + .convert_cf_to_llvm() + ) + + p2 = ( + cse() + .lower_affine() + .Func(convert_arith_to_llvm()) + .convert_func_to_llvm() + .canonicalize() + .convert_openmp_to_llvm() + .cse() + .reconcile_unrealized_casts() + ) + + assert ( + p1 + p2 + ).materialize() == "builtin.module(cse,func.func(lower-affine,arith-expand,convert-math-to-llvm),convert-math-to-libm,expand-strided-metadata,finalize-memref-to-llvm,convert-scf-to-cf,convert-cf-to-llvm,cse,lower-affine,func.func(convert-arith-to-llvm),convert-func-to-llvm,canonicalize,convert-openmp-to-llvm,cse,reconcile-unrealized-casts)"