|
3 | 3 |
|
4 | 4 | ; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s
|
5 | 5 |
|
6 |
| -; CHECK: %[[#Alloca:]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) |
7 |
| -; CHECK: %[[#Cast:]] = addrspacecast ptr %[[#Alloca]] to ptr addrspace(4) |
8 |
| -; CHECK: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef %[[#Cast]], i64 noundef 0) |
9 |
| - |
10 | 6 | ; ModuleID = 'test.bc'
|
11 | 7 | source_filename = "test.cpp"
|
12 | 8 | target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
|
13 | 9 | target triple = "spir64-unknown-unknown"
|
14 | 10 |
|
15 |
| -%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) } |
| 11 | +%"struct.sycl::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) } |
| 12 | +%"struct.sycl::_V1::long" = type { i64 } |
| 13 | + |
| 14 | +define weak_odr dso_local spir_kernel void @test(i64 %ind) { |
| 15 | +; CHECK-LABEL: define weak_odr dso_local spir_kernel void @test( |
| 16 | +; CHECK-SAME: i64 [[IND:%.*]]) { |
| 17 | + |
| 18 | +; non-matrix alloca not touched |
| 19 | +; CHECK: [[NOT_MATR:%.*]] = alloca [2 x [4 x %"struct.sycl::_V1::long"]] |
| 20 | +; both matrix-related allocas updated to use target extension types |
| 21 | +; CHECK-NEXT: [[MATR:%.*]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) |
| 22 | +; CHECK-NEXT: [[MATR_ARR:%.*]] = alloca [2 x [4 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]] |
| 23 | + |
| 24 | +; CHECK-NEXT: [[ASCAST:%.*]] = addrspacecast ptr [[MATR]] to ptr addrspace(4) |
| 25 | +; no gep inserted, since not needed |
| 26 | +; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[ASCAST]], i64 noundef 0) |
| 27 | + |
| 28 | +; CHECK: [[GEP:%.*]] = getelementptr inbounds [2 x [4 x %"struct.sycl::joint_matrix"]], ptr [[MATR_ARR]], i64 0, i64 [[IND]], i64 [[IND]] |
| 29 | +; CHECK-NEXT: [[ASCAST_1:%.*]] = addrspacecast ptr [[GEP]] to ptr addrspace(4) |
| 30 | +; CHECK-NEXT: [[ASCAST_2:%.*]] = addrspacecast ptr [[GEP]] to ptr addrspace(4) |
| 31 | +; gep is inserted for each of the accesschain calls to extract target extension type |
| 32 | +; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds %"struct.sycl::joint_matrix", ptr addrspace(4) [[ASCAST_1]], i64 0, i32 0 |
| 33 | +; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP2]], i64 noundef 0) |
| 34 | +; CHECK: [[TMP5:%.*]] = getelementptr inbounds %"struct.sycl::joint_matrix", ptr addrspace(4) [[ASCAST_2]], i64 0, i32 0 |
| 35 | +; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP5]], i64 noundef 0) |
| 36 | + |
| 37 | +; negative test - not touching non-matrix code |
| 38 | +; CHECK: [[GEP_1:%.*]] = getelementptr inbounds [2 x [4 x %"struct.sycl::_V1::long"]], ptr [[NOT_MATR]], i64 0, i64 [[IND]], i64 [[IND]] |
| 39 | +; CHECK-NEXT: [[ASCAST_3:%.*]] = addrspacecast ptr [[GEP_1]] to ptr addrspace(4) |
| 40 | +; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[ASCAST_3]], i64 noundef 0) |
16 | 41 |
|
17 |
| -define weak_odr dso_local spir_kernel void @test() { |
18 | 42 | entry:
|
19 |
| - %0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8 |
20 |
| - %1 = addrspacecast ptr %0 to ptr addrspace(4) |
21 |
| - %2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0) |
| 43 | + ; allocas |
| 44 | + %matr = alloca %"struct.sycl::joint_matrix", align 8 |
| 45 | + %matr.arr = alloca [2 x [4 x %"struct.sycl::joint_matrix"]], align 8 |
| 46 | + %not.matr = alloca [2 x [4 x %"struct.sycl::_V1::long"]], align 8 |
| 47 | + |
| 48 | + ; simple case |
| 49 | + %ascast = addrspacecast ptr %matr to ptr addrspace(4) |
| 50 | + %0 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast, i64 noundef 0) |
| 51 | + %1 = load i8, ptr addrspace(4) %0 |
| 52 | + |
| 53 | + ; gep with non-zero inidices and multiple access chains per 1 alloca |
| 54 | + %gep = getelementptr inbounds [2 x [4 x %"struct.sycl::joint_matrix"]], ptr %matr.arr, i64 0, i64 %ind, i64 %ind |
| 55 | + %ascast.1 = addrspacecast ptr %gep to ptr addrspace(4) |
| 56 | + %ascast.2 = addrspacecast ptr %gep to ptr addrspace(4) |
| 57 | + %2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.1, i64 noundef 0) |
22 | 58 | %3 = load i8, ptr addrspace(4) %2
|
| 59 | + %4 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.2, i64 noundef 0) |
| 60 | + %5 = load i8, ptr addrspace(4) %4 |
| 61 | + |
| 62 | + ; negative test - not touching non-matrix code |
| 63 | + %gep.1 = getelementptr inbounds [2 x [4 x %"struct.sycl::_V1::long"]], ptr %not.matr, i64 0, i64 %ind, i64 %ind |
| 64 | + %ascast.3 = addrspacecast ptr %gep.1 to ptr addrspace(4) |
| 65 | + %6 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.3, i64 noundef 0) |
| 66 | + %7 = load i8, ptr addrspace(4) %6 |
| 67 | + |
23 | 68 | ret void
|
24 | 69 | }
|
25 | 70 |
|
26 |
| -declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef, i64 noundef) |
| 71 | +declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef, i64 noundef) |
0 commit comments