@@ -1,4 +1,4 @@
-// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s
+// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=reshape_dynamic_slice(1);reshape_licm(1);transpose_dynamic_slice;transpose_licm(1)" --transform-interpreter --enzyme-hlo-remove-transform --enzyme-hlo-opt %s | FileCheck %s
 
 module {
   func.func @main(%arg0: tensor<10xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<10xf32>) -> tensor<10xf32> {
@@ -165,3 +165,133 @@ module {
 // CHECK-NEXT: %2 = stablehlo.transpose %1, dims = [2, 0, 1] : (tensor<5x10x4xf32>) -> tensor<4x5x10xf32>
 // CHECK-NEXT: return %2 : tensor<4x5x10xf32>
 // CHECK-NEXT: }
+
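+// An identity round trip: the loop below reshapes %arg0 to 10x1 and copies it back one element per iteration, so the whole loop should fold to returning %arg0 unchanged.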
+module {
+  func.func @main(%arg0: tensor<10xf64>) -> tensor<10xf64> {
+    %c = stablehlo.constant dense<1> : tensor<i32>
+    %c_0 = stablehlo.constant dense<0> : tensor<i64>
+    %c_1 = stablehlo.constant dense<10> : tensor<i64>
+    %c_2 = stablehlo.constant dense<1> : tensor<i64>
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<10xf64>
+    %0 = stablehlo.reshape %arg0 : (tensor<10xf64>) -> tensor<10x1xf64>
+    %1:2 = stablehlo.while(%iterArg = %c_0, %iterArg_3 = %cst) : tensor<i64>, tensor<10xf64>
+     cond {
+      %2 = stablehlo.compare LT, %iterArg, %c_1 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+      stablehlo.return %2 : tensor<i1>
+    } do {
+      %2 = stablehlo.add %c_2, %iterArg : tensor<i64>
+      %3 = stablehlo.convert %2 : (tensor<i64>) -> tensor<i32>
+      %4 = stablehlo.subtract %3, %c : tensor<i32>
+      %5 = stablehlo.dynamic_slice %0, %iterArg, %c_0, sizes = [1, 1] : (tensor<10x1xf64>, tensor<i64>, tensor<i64>) -> tensor<1x1xf64>
+      %6 = stablehlo.reshape %5 : (tensor<1x1xf64>) -> tensor<1xf64>
+      %7 = stablehlo.dynamic_update_slice %iterArg_3, %6, %4 : (tensor<10xf64>, tensor<1xf64>, tensor<i32>) -> tensor<10xf64>
+      stablehlo.return %2, %7 : tensor<i64>, tensor<10xf64>
+    }
+    return %1#1 : tensor<10xf64>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<10xf64>) -> tensor<10xf64> {
+// CHECK-NEXT: return %arg0 : tensor<10xf64>
+// CHECK-NEXT: }
+
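+// The broadcast_in_dim below is just a permutation plus a unit dimension, so the slice-by-slice copy loop should collapse into a single stablehlo.transpose.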
+module {
+  func.func @main(%arg0: tensor<5x4x3xf32>) -> tensor<4x5x3xf32> {
+    %c = stablehlo.constant dense<0> : tensor<i32>
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<4x5x3xf32>
+    %c_0 = stablehlo.constant dense<1> : tensor<i32>
+    %c_1 = stablehlo.constant dense<0> : tensor<i64>
+    %c_2 = stablehlo.constant dense<4> : tensor<i64>
+    %c_3 = stablehlo.constant dense<1> : tensor<i64>
+    %0 = stablehlo.broadcast_in_dim %arg0, dims = [2, 0, 3] : (tensor<5x4x3xf32>) -> tensor<4x1x5x3xf32>
+    %1:2 = stablehlo.while(%iterArg = %c_1, %iterArg_4 = %cst) : tensor<i64>, tensor<4x5x3xf32>
+     cond {
+      %2 = stablehlo.compare LT, %iterArg, %c_2 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+      stablehlo.return %2 : tensor<i1>
+    } do {
+      %2 = stablehlo.add %c_3, %iterArg : tensor<i64>
+      %3 = stablehlo.convert %2 : (tensor<i64>) -> tensor<i32>
+      %4 = stablehlo.subtract %3, %c_0 : tensor<i32>
+      %5 = stablehlo.dynamic_slice %0, %iterArg, %c_1, %c_1, %c_1, sizes = [1, 1, 5, 3] : (tensor<4x1x5x3xf32>, tensor<i64>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x1x5x3xf32>
+      %6 = stablehlo.reshape %5 : (tensor<1x1x5x3xf32>) -> tensor<1x5x3xf32>
+      %7 = stablehlo.dynamic_update_slice %iterArg_4, %6, %4, %c, %c : (tensor<4x5x3xf32>, tensor<1x5x3xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<4x5x3xf32>
+      stablehlo.return %2, %7 : tensor<i64>, tensor<4x5x3xf32>
+    }
+    return %1#1 : tensor<4x5x3xf32>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<5x4x3xf32>) -> tensor<4x5x3xf32> {
+// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0, 2] : (tensor<5x4x3xf32>) -> tensor<4x5x3xf32>
+// CHECK-NEXT: return %0 : tensor<4x5x3xf32>
+// CHECK-NEXT: }
+
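+// Every dim-1 slice of the init operand is overwritten from the transposed %arg1, so the loop should reduce to the transpose followed by a reshape.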
+module {
+  func.func @main(%arg0: tensor<5x4x3xf32>, %arg1: tensor<3x1x4x1x5xf32>) -> tensor<5x4x3xf32> {
+    %c = stablehlo.constant dense<0> : tensor<i32>
+    %c_0 = stablehlo.constant dense<1> : tensor<i32>
+    %c_1 = stablehlo.constant dense<0> : tensor<i64>
+    %c_2 = stablehlo.constant dense<1> : tensor<i64>
+    %c_3 = stablehlo.constant dense<4> : tensor<i64>
+    %0 = stablehlo.transpose %arg1, dims = [4, 1, 2, 3, 0] : (tensor<3x1x4x1x5xf32>) -> tensor<5x1x4x1x3xf32>
+    %1:2 = stablehlo.while(%iterArg = %c_1, %iterArg_3 = %arg0) : tensor<i64>, tensor<5x4x3xf32>
+     cond {
+      %2 = stablehlo.compare LT, %iterArg, %c_3 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+      stablehlo.return %2 : tensor<i1>
+    } do {
+      %2 = stablehlo.add %c_2, %iterArg : tensor<i64>
+      %3 = stablehlo.convert %2 : (tensor<i64>) -> tensor<i32>
+      %4 = stablehlo.subtract %3, %c_0 : tensor<i32>
+      %5 = stablehlo.dynamic_slice %0, %c, %c, %4, %c, %c, sizes = [5, 1, 1, 1, 3] : (tensor<5x1x4x1x3xf32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5x1x1x1x3xf32>
+      %6 = stablehlo.reshape %5 : (tensor<5x1x1x1x3xf32>) -> tensor<5x1x3xf32>
+      %7 = stablehlo.dynamic_update_slice %iterArg_3, %6, %c, %4, %c : (tensor<5x4x3xf32>, tensor<5x1x3xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5x4x3xf32>
+      stablehlo.return %2, %7 : tensor<i64>, tensor<5x4x3xf32>
+    }
+    return %1#1 : tensor<5x4x3xf32>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<5x4x3xf32>, %arg1: tensor<3x1x4x1x5xf32>) -> tensor<5x4x3xf32> {
+// CHECK-NEXT: %0 = stablehlo.transpose %arg1, dims = [4, 1, 2, 3, 0] : (tensor<3x1x4x1x5xf32>) -> tensor<5x1x4x1x3xf32>
+// CHECK-NEXT: %1 = stablehlo.reshape %0 : (tensor<5x1x4x1x3xf32>) -> tensor<5x4x3xf32>
+// CHECK-NEXT: return %1 : tensor<5x4x3xf32>
+// CHECK-NEXT: }
+
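+// Only three of the four dim-2 slices of %arg0 are rewritten (the trip count is 3), so the folded form concatenates the rewritten prefix with the untouched final slice.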
+module {
+  func.func @main(%arg0: tensor<1x3x4x1x5xf32>, %arg1: tensor<5x4x3xf32>) -> tensor<1x3x4x1x5xf32> {
+    %c = stablehlo.constant dense<0> : tensor<i32>
+    %c_0 = stablehlo.constant dense<1> : tensor<i32>
+    %c_1 = stablehlo.constant dense<0> : tensor<i64>
+    %c_2 = stablehlo.constant dense<1> : tensor<i64>
+    %c_3 = stablehlo.constant dense<3> : tensor<i64>
+    %0:2 = stablehlo.while(%iterArg = %c_1, %iterArg_4 = %arg0) : tensor<i64>, tensor<1x3x4x1x5xf32>
+     cond {
+      %1 = stablehlo.compare LT, %iterArg, %c_3 : (tensor<i64>, tensor<i64>) -> tensor<i1>
+      stablehlo.return %1 : tensor<i1>
+    } do {
+      %1 = stablehlo.add %c_2, %iterArg : tensor<i64>
+      %2 = stablehlo.convert %1 : (tensor<i64>) -> tensor<i32>
+      %3 = stablehlo.subtract %2, %c_0 : tensor<i32>
+      %4 = stablehlo.dynamic_slice %arg1, %c, %3, %c, sizes = [5, 1, 3] : (tensor<5x4x3xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<5x1x3xf32>
+      %5 = stablehlo.reshape %4 : (tensor<5x1x3xf32>) -> tensor<5x1x1x3x1xf32>
+      %6 = stablehlo.transpose %5, dims = [4, 3, 2, 1, 0] : (tensor<5x1x1x3x1xf32>) -> tensor<1x3x1x1x5xf32>
+      %7 = stablehlo.dynamic_update_slice %iterArg_4, %6, %c, %c, %3, %c, %c : (tensor<1x3x4x1x5xf32>, tensor<1x3x1x1x5xf32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<1x3x4x1x5xf32>
+      stablehlo.return %1, %7 : tensor<i64>, tensor<1x3x4x1x5xf32>
+    }
+    return %0#1 : tensor<1x3x4x1x5xf32>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<1x3x4x1x5xf32>, %arg1: tensor<5x4x3xf32>) -> tensor<1x3x4x1x5xf32> {
+// CHECK-NEXT: %0 = stablehlo.slice %arg1 [0:5, 0:3, 0:3] : (tensor<5x4x3xf32>) -> tensor<5x3x3xf32>
+// CHECK-NEXT: %1 = stablehlo.reshape %0 : (tensor<5x3x3xf32>) -> tensor<5x3x1x3x1xf32>
+// CHECK-NEXT: %2 = stablehlo.transpose %1, dims = [4, 3, 2, 1, 0] : (tensor<5x3x1x3x1xf32>) -> tensor<1x3x1x3x5xf32>
+// CHECK-NEXT: %3 = stablehlo.reshape %2 : (tensor<1x3x1x3x5xf32>) -> tensor<1x3x3x1x5xf32>
+// CHECK-NEXT: %4 = stablehlo.slice %arg0 [0:1, 0:3, 3:4, 0:1, 0:5] : (tensor<1x3x4x1x5xf32>) -> tensor<1x3x1x1x5xf32>
+// CHECK-NEXT: %5 = stablehlo.concatenate %3, %4, dim = 2 : (tensor<1x3x3x1x5xf32>, tensor<1x3x1x1x5xf32>) -> tensor<1x3x4x1x5xf32>
+// CHECK-NEXT: return %5 : tensor<1x3x4x1x5xf32>
+// CHECK-NEXT: }