Skip to content

Commit be84e1b

Browse files
committed
add test
1 parent 7da8f0f commit be84e1b

File tree

3 files changed

+82
-0
lines changed

3 files changed

+82
-0
lines changed

enzyme/test/MLIR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_subdirectory(ActivityAnalysis)
22
add_subdirectory(AliasAnalysis)
33
add_subdirectory(Batch)
44
add_subdirectory(ForwardMode)
5+
add_subdirectory(OptimizeAD)
56
add_subdirectory(Passes)
67
add_subdirectory(ProbProg)
78
add_subdirectory(ReverseMode)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Run regression and unit tests
2+
add_lit_testsuite(check-enzymemlir-optimizead "Running MLIR OptimizeAD tests"
3+
${CMAKE_CURRENT_BINARY_DIR}
4+
DEPENDS enzymemlir-opt
5+
ARGS -v
6+
)
7+
8+
set_target_properties(check-enzymemlir-optimizead PROPERTIES FOLDER "Tests")
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// RUN: enzymemlir-opt --split-input-file --hoist-enzyme-regions %s | FileCheck %s
2+
// CHECK-LABEL: func.func @foo
3+
// CHECK-SAME: (%arg0: f64, %arg1: f64, %arg2: f64) -> f64
4+
// CHECK: %c10 = arith.constant 10 : index
5+
// CHECK: %c1 = arith.constant 1 : index
6+
// CHECK: %cst = arith.constant 2.500000e+00 : f64
7+
// CHECK: %cst_0 = arith.constant 2.000000e+00 : f64
8+
// CHECK: %cst_1 = arith.constant 0.000000e+00 : f64
9+
// CHECK: %cst_2 = arith.constant 1.000000e+02 : f64
10+
// CHECK: %0 = arith.mulf %arg2, %cst_0 : f64
11+
// CHECK: %1 = scf.for %{{.*}} = %c1 to %c10 step %c1 iter_args(%{{.*}} = %cst) -> (f64) {
12+
// CHECK: %{{.*}} = arith.mulf %{{.*}}, %cst_2 : f64
13+
// CHECK: %{{.*}} = scf.for %{{.*}} = %c1 to %c10 step %c1 iter_args(%{{.*}} = %{{.*}}) -> (f64) {
14+
// CHECK: %{{.*}} = arith.addf %{{.*}}, %0 : f64
15+
// CHECK: scf.yield %{{.*}} : f64
16+
// CHECK: }
17+
// CHECK: scf.yield %{{.*}} : f64
18+
// CHECK: }
19+
// CHECK: %2 = enzyme.autodiff_region(%arg0, %arg1) {
20+
// CHECK: ^bb0(%{{.*}}: f64):
21+
// CHECK: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64
22+
// CHECK: %{{.*}} = arith.mulf %{{.*}}, %0 : f64
23+
// CHECK: %{{.*}} = arith.mulf %{{.*}}, %1 : f64
24+
// CHECK: %{{.*}} = arith.addf %{{.*}}, %cst_1 : f64
25+
// CHECK: %{{.*}} = scf.for %{{.*}} = %c1 to %c10 step %c1 iter_args(%{{.*}} = %{{.*}}) -> (f64) {
26+
// CHECK: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64
27+
// CHECK: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64
28+
// CHECK: scf.yield %{{.*}} : f64
29+
// CHECK: }
30+
// CHECK: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64
31+
// CHECK: enzyme.yield %{{.*}} : f64
32+
// CHECK: } attributes {{.*}} : (f64, f64) -> f64
33+
34+
func.func @foo(%arg0: f64, %arg1: f64,%xx: f64) -> f64 {
35+
36+
%yy_cst = arith.constant 100.0 : f64
37+
%0 = enzyme.autodiff_region(%arg0, %arg1) {
38+
^bb0(%arg2: f64):
39+
// hoistable constant ops
40+
%c0 = arith.constant 0.0 : f64
41+
%c1 = arith.constant 1.0 : f64
42+
%c2 = arith.constant 2.0 : f64
43+
%cx = arith.mulf %c2, %xx : f64
44+
45+
%sq = arith.mulf %arg2, %arg2 : f64
46+
%sqx = arith.mulf %sq, %cx : f64
47+
48+
// hoistable loops
49+
%yy0 = arith.constant 2.5 : f64
50+
%one = arith.constant 1 : index
51+
%ten = arith.constant 10 : index
52+
%yy = scf.for %iv = %one to %ten step %one iter_args(%yy_iter = %yy0) -> (f64) {
53+
%tm = arith.mulf %yy_iter, %yy_cst : f64
54+
%ta = scf.for %jv = %one to %ten step %one iter_args(%tm_iter = %tm) -> (f64) {
55+
%ta = arith.addf %tm, %cx : f64
56+
scf.yield %ta : f64
57+
}
58+
scf.yield %ta : f64
59+
}
60+
61+
%sqxy = arith.mulf %sqx, %yy : f64
62+
%zz0 = arith.addf %sqx, %c0 : f64
63+
%zz = scf.for %iv = %one to %ten step %one iter_args(%zz_iter = %zz0) ->(f64) {
64+
%zm = arith.addf %zz_iter, %sqx : f64
65+
%zout = arith.mulf %zm, %zz_iter : f64
66+
scf.yield %zout : f64
67+
}
68+
69+
%sqxyz = arith.mulf %zz, %sqxy : f64
70+
enzyme.yield %sqxyz : f64
71+
} attributes {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>]} : (f64, f64) -> f64
72+
return %0 : f64
73+
}

0 commit comments

Comments
 (0)