-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtest_pipeline.py
85 lines (76 loc) · 3 KB
/
test_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from mlir.extras.runtime.passes import Pipeline as pipe
def test_basic():
    """Pipeline builders must materialize to the expected MLIR pass string.

    Covers one uninterrupted builder chain, concatenation of two partial
    pipelines via ``+``, ``str()`` rendering of a partial pipeline, and
    in-place extension via ``+=``.
    """
    # Textual pass lists for the two halves of the lowering pipeline.
    head_passes = (
        "cse,"
        "func.func(lower-affine,arith-expand,convert-math-to-llvm),"
        "convert-math-to-libm,"
        "expand-strided-metadata,"
        "finalize-memref-to-llvm,"
        "convert-scf-to-cf,"
        "convert-cf-to-llvm"
    )
    tail_passes = (
        "cse,"
        "lower-affine,"
        "func.func(convert-arith-to-llvm),"
        "convert-func-to-llvm,"
        "canonicalize,"
        "convert-openmp-to-llvm,"
        "cse,"
        "reconcile-unrealized-casts"
    )
    expected_head = f"builtin.module({head_passes})"
    expected_full = f"builtin.module({head_passes},{tail_passes})"

    def build_head():
        # First half: everything up to and including convert-cf-to-llvm.
        return (
            pipe()
            .cse()
            .Func(pipe().lower_affine().arith_expand().convert_math_to_llvm())
            .convert_math_to_libm()
            .expand_strided_metadata()
            .finalize_memref_to_llvm()
            .convert_scf_to_cf()
            .convert_cf_to_llvm()
        )

    def build_tail():
        # Second half: from the second cse through reconcile-unrealized-casts.
        return (
            pipe()
            .cse()
            .lower_affine()
            .Func(pipe().convert_arith_to_llvm())
            .convert_func_to_llvm()
            .canonicalize()
            .convert_openmp_to_llvm()
            .cse()
            .reconcile_unrealized_casts()
        )

    # The whole pipeline spelled out as a single builder chain.
    full = (
        pipe()
        .cse()
        .Func(pipe().lower_affine().arith_expand().convert_math_to_llvm())
        .convert_math_to_libm()
        .expand_strided_metadata()
        .finalize_memref_to_llvm()
        .convert_scf_to_cf()
        .convert_cf_to_llvm()
        .cse()
        .lower_affine()
        .Func(pipe().convert_arith_to_llvm())
        .convert_func_to_llvm()
        .canonicalize()
        .convert_openmp_to_llvm()
        .cse()
        .reconcile_unrealized_casts()
    )
    assert full.materialize() == expected_full

    # Joining the two halves with ``+`` must yield the same text.
    head, tail = build_head(), build_tail()
    assert (head + tail).materialize() == expected_full

    # A freshly built head renders only its own passes; ``+=`` appends the
    # same tail object used above.
    head = build_head()
    assert str(head) == expected_head
    head += tail
    assert str(head) == expected_full
def test_context():
    """Passes given to ``Nested`` render wrapped in the named op inside builtin.module."""
    device_passes = (
        pipe()
        .add_pass("aie-localize-locks")
        .add_pass("aie-normalize-address-spaces")
    )
    nested = pipe().Nested("aie.device", device_passes)

    expected = (
        "builtin.module(aie.device(aie-localize-locks,aie-normalize-address-spaces))"
    )
    assert str(nested) == expected