@@ -39,23 +39,23 @@ func.func @test_add_generate_symmetry(%arg0: tensor<3x2x3xf32>) -> tensor<3x2x3x
3939// CHECK-NEXT: return %1 : tensor<3x2x3xf32>
4040// CHECK-NEXT: }
4141
42- func.func @test_dot_propagate (%arg0: tensor <2 x2 x3 xf32 > {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry <<[0 , 1 ]>>]}) -> tensor <2 x2 xf32 > {
42+ func.func @test_dot_general_propagate (%arg0: tensor <2 x2 x3 xf32 > {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry <<[0 , 1 ]>>]}) -> tensor <2 x2 xf32 > {
4343 %cst1 = stablehlo.constant dense <[[[1.0 , 2.0 ], [2.0 , 3.0 ]], [[2.0 , 3.0 ], [3.0 , 4.0 ]], [[2.0 , 3.0 ], [3.0 , 4.0 ]]]> : tensor <3 x2 x2 xf32 >
4444 %0 = stablehlo.dot_general %arg0 , %cst1 , batching_dims = [0 , 1 ] x [1 , 2 ], contracting_dims = [2 ] x [0 ] : (tensor <2 x2 x3 xf32 >, tensor <3 x2 x2 xf32 >) -> tensor <2 x2 xf32 >
4545 return %0 : tensor <2 x2 xf32 >
4646}
47- // CHECK: func.func @test_dot_propagate (%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2xf32> {
47+ // CHECK: func.func @test_dot_general_propagate (%arg0: tensor<2x2x3xf32> {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]}) -> tensor<2x2xf32> {
4848// CHECK-NEXT: %cst = stablehlo.constant {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} dense<{{.*}}> : tensor<3x2x2xf32>
4949// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %cst, batching_dims = [0, 1] x [1, 2], contracting_dims = [2] x [0] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[0, 1]>>]} : (tensor<2x2x3xf32>, tensor<3x2x2xf32>) -> tensor<2x2xf32>
5050// CHECK-NEXT: return %0 : tensor<2x2xf32>
5151// CHECK-NEXT: }
5252
53- func.func @test_dot_generate_symmetry (%arg0: tensor <3 x3 x3 xf32 >) -> tensor <3 x3 x3 xf32 > {
53+ func.func @test_dot_general_generate_symmetry (%arg0: tensor <3 x3 x3 xf32 >) -> tensor <3 x3 x3 xf32 > {
5454 %0 = stablehlo.transpose %arg0 , dims = [2 , 1 , 0 ] : (tensor <3 x3 x3 xf32 >) -> tensor <3 x3 x3 xf32 >
5555 %1 = stablehlo.dot_general %arg0 , %0 , batching_dims = [1 ] x [1 ], contracting_dims = [0 ] x [2 ] : (tensor <3 x3 x3 xf32 >, tensor <3 x3 x3 xf32 >) -> tensor <3 x3 x3 xf32 >
5656 return %1 : tensor <3 x3 x3 xf32 >
5757}
58- // CHECK: func.func @test_dot_generate_symmetry (%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> {
58+ // CHECK: func.func @test_dot_general_generate_symmetry (%arg0: tensor<3x3x3xf32>) -> tensor<3x3x3xf32> {
5959// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x3x3xf32>) -> tensor<3x3x3xf32>
6060// CHECK-NEXT: %1 = stablehlo.dot_general %arg0, %0, batching_dims = [1] x [1], contracting_dims = [0] x [2] {enzymexla.partial_symmetry = [#enzymexla.partial_symmetry<<[1, 2]>>]} : (tensor<3x3x3xf32>, tensor<3x3x3xf32>) -> tensor<3x3x3xf32>
6161// CHECK-NEXT: return %1 : tensor<3x3x3xf32>
0 commit comments