Skip to content

Commit

Permalink
Pipeline improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental committed Jan 15, 2024
1 parent 185266f commit 5d0faef
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
8 changes: 6 additions & 2 deletions mlir/extras/runtime/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tempfile
from contextlib import ExitStack
from io import StringIO
from typing import Optional, List
from typing import Optional, List, Union

from ...ir import StringAttr, Module
from ...passmanager import PassManager
Expand All @@ -26,7 +26,7 @@ def get_module_name_for_debug_dump(module):

def run_pipeline(
module,
pipeline: str,
pipeline: Union[str, "Pipeline"],
description: Optional[str] = None,
enable_ir_printing=False,
print_pipeline=False,
Expand Down Expand Up @@ -116,6 +116,10 @@ def materialize(self, module=True):
def __str__(self):
return self.materialize()

def __add__(self, other: "Pipeline"):
self._pipeline.extend(other._pipeline)
return self

def add_pass(self, pass_name, **kwargs):
kwargs = {
k.replace("_", "-"): int(v) if isinstance(v, bool) else v
Expand Down
25 changes: 25 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,28 @@ def test_basic():
p.materialize()
== "builtin.module(cse,func.func(lower-affine,arith-expand,convert-math-to-llvm),convert-math-to-libm,expand-strided-metadata,finalize-memref-to-llvm,convert-scf-to-cf,convert-cf-to-llvm,cse,lower-affine,func.func(convert-arith-to-llvm),convert-func-to-llvm,canonicalize,convert-openmp-to-llvm,cse,reconcile-unrealized-casts)"
)

p1 = (
cse()
.Func(lower_affine().arith_expand().convert_math_to_llvm())
.convert_math_to_libm()
.expand_strided_metadata()
.finalize_memref_to_llvm()
.convert_scf_to_cf()
.convert_cf_to_llvm()
)

p2 = (
cse()
.lower_affine()
.Func(convert_arith_to_llvm())
.convert_func_to_llvm()
.canonicalize()
.convert_openmp_to_llvm()
.cse()
.reconcile_unrealized_casts()
)

assert (
p1 + p2
).materialize() == "builtin.module(cse,func.func(lower-affine,arith-expand,convert-math-to-llvm),convert-math-to-libm,expand-strided-metadata,finalize-memref-to-llvm,convert-scf-to-cf,convert-cf-to-llvm,cse,lower-affine,func.func(convert-arith-to-llvm),convert-func-to-llvm,canonicalize,convert-openmp-to-llvm,cse,reconcile-unrealized-casts)"

0 comments on commit 5d0faef

Please sign in to comment.