Skip to content

Commit ae21aaf

Browse files
committed
Merge remote-tracking branch 'upstream/sycl' into HEAD
2 parents 437fe59 + ee6969f commit ae21aaf

File tree

4 files changed

+180
-38
lines changed

4 files changed

+180
-38
lines changed

compiler-rt/lib/builtins/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -850,9 +850,12 @@ else ()
850850
if (CAN_TARGET_${arch})
851851
cmake_push_check_state()
852852
# TODO: we should probably make most of the checks in builtin-config depend on the target flags.
853-
message(STATUS "Performing additional configure checks with target flags: ${TARGET_${arch}_CFLAGS}")
854853
set(BUILTIN_CFLAGS_${arch} ${BUILTIN_CFLAGS})
855-
list(APPEND CMAKE_REQUIRED_FLAGS ${TARGET_${arch}_CFLAGS} ${BUILTIN_CFLAGS_${arch}})
854+
# CMAKE_REQUIRED_FLAGS must be a space separated string but unlike TARGET_${arch}_CFLAGS,
855+
# BUILTIN_CFLAGS_${arch} is a CMake list, so we have to join it to create a valid command line.
856+
list(JOIN BUILTIN_CFLAGS " " CMAKE_REQUIRED_FLAGS)
857+
set(CMAKE_REQUIRED_FLAGS "${TARGET_${arch}_CFLAGS} ${BUILTIN_CFLAGS_${arch}}")
858+
message(STATUS "Performing additional configure checks with target flags: ${CMAKE_REQUIRED_FLAGS}")
856859
# For ARM archs, exclude any VFP builtins if VFP is not supported
857860
if (${arch} MATCHES "^(arm|armhf|armv7|armv7s|armv7k|armv7m|armv7em|armv8m.main|armv8.1m.main)$")
858861
string(REPLACE ";" " " _TARGET_${arch}_CFLAGS "${TARGET_${arch}_CFLAGS}")

llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp

Lines changed: 113 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,74 @@ namespace {
2222
static constexpr char ACCESS_CHAIN[] = "_Z19__spirv_AccessChain";
2323
static constexpr char MATRIX_TYPE[] = "spirv.CooperativeMatrixKHR";
2424

25+
Type *getInnermostType(Type *Ty) {
26+
while (auto *ArrayTy = dyn_cast<ArrayType>(Ty))
27+
Ty = ArrayTy->getElementType();
28+
return Ty;
29+
}
30+
31+
Type *replaceInnermostType(Type *Ty, Type *NewInnermostTy) {
32+
if (auto *ArrayTy = dyn_cast<ArrayType>(Ty))
33+
return ArrayType::get(
34+
replaceInnermostType(ArrayTy->getElementType(), NewInnermostTy),
35+
ArrayTy->getNumElements());
36+
return NewInnermostTy;
37+
}
38+
39+
// This function is a copy of stripPointerCastsAndOffsets from Value.cpp,
40+
// simplified and modified to strip non-zero GEP indices as well and also
41+
// find nearest GEP instruction.
42+
Value *stripPointerCastsAndOffsets(Value *V, bool StopOnGEP = false) {
43+
if (!V->getType()->isPointerTy())
44+
return V;
45+
46+
// Even though we don't look through PHI nodes, we could be called on an
47+
// instruction in an unreachable block, which may be on a cycle.
48+
SmallPtrSet<Value *, 4> Visited;
49+
50+
Visited.insert(V);
51+
do {
52+
if (auto *GEP = dyn_cast<GEPOperator>(V)) {
53+
if (StopOnGEP && isa<GetElementPtrInst>(GEP))
54+
return V;
55+
V = GEP->getPointerOperand();
56+
} else if (auto *BC = dyn_cast<BitCastOperator>(V)) {
57+
Value *NewV = BC->getOperand(0);
58+
if (!NewV->getType()->isPointerTy())
59+
return V;
60+
V = NewV;
61+
} else if (auto *ASC = dyn_cast<AddrSpaceCastOperator>(V)) {
62+
V = ASC->getOperand(0);
63+
} else {
64+
if (auto *Call = dyn_cast<CallBase>(V)) {
65+
if (Value *RV = Call->getReturnedArgOperand()) {
66+
V = RV;
67+
// Strip the call instruction, since callee returns its RV
68+
// argument as return value. So, we need to continue stripping.
69+
continue;
70+
}
71+
}
72+
return V;
73+
}
74+
assert(V->getType()->isPointerTy() && "Unexpected operand type!");
75+
} while (Visited.insert(V).second);
76+
77+
return V;
78+
}
79+
80+
TargetExtType *extractMatrixType(StructType *WrapperMatrixTy) {
81+
if (!WrapperMatrixTy)
82+
return nullptr;
83+
TargetExtType *MatrixTy =
84+
dyn_cast<TargetExtType>(WrapperMatrixTy->getElementType(0));
85+
86+
if (!MatrixTy)
87+
return nullptr;
88+
if (MatrixTy->getName() != MATRIX_TYPE)
89+
return nullptr;
90+
return MatrixTy;
91+
}
92+
2593
// This function finds all calls to __spirv_AccessChain function and transforms
2694
// its users and operands to make LLVM IR more SPIR-V friendly.
2795
bool transformAccessChain(Function *F) {
@@ -60,33 +128,59 @@ bool transformAccessChain(Function *F) {
60128
// from sycl::joint_matrix class object if it's used in __spirv_AccessChain
61129
// function call. It's necessary because otherwise OpAccessChain indices
62130
// would be wrong.
63-
Instruction *Ptr =
64-
dyn_cast<Instruction>(CI->getArgOperand(0)->stripPointerCasts());
131+
Instruction *Ptr = dyn_cast<Instruction>(
132+
stripPointerCastsAndOffsets(CI->getArgOperand(0)));
65133
if (!Ptr || !isa<AllocaInst>(Ptr))
66134
continue;
67-
StructType *WrapperMatrixTy =
68-
dyn_cast<StructType>(cast<AllocaInst>(Ptr)->getAllocatedType());
69-
if (!WrapperMatrixTy)
70-
continue;
71-
TargetExtType *MatrixTy =
72-
dyn_cast<TargetExtType>(WrapperMatrixTy->getElementType(0));
73-
if (!MatrixTy)
135+
136+
Type *AllocaTy = cast<AllocaInst>(Ptr)->getAllocatedType();
137+
// It may happen that sycl::joint_matrix class object is wrapped into
138+
// nested arrays. We need to find the innermost type to extract
139+
if (StructType *WrapperMatrixTy =
140+
dyn_cast<StructType>(getInnermostType(AllocaTy))) {
141+
TargetExtType *MatrixTy = extractMatrixType(WrapperMatrixTy);
142+
if (!MatrixTy)
143+
continue;
144+
145+
AllocaInst *Alloca = nullptr;
146+
{
147+
IRBuilder Builder(CI);
148+
IRBuilderBase::InsertPointGuard IG(Builder);
149+
Builder.SetInsertPointPastAllocas(CI->getFunction());
150+
Alloca = Builder.CreateAlloca(replaceInnermostType(AllocaTy, MatrixTy));
151+
Alloca->takeName(Ptr);
152+
}
153+
Ptr->replaceAllUsesWith(Alloca);
154+
Ptr->dropAllReferences();
155+
Ptr->eraseFromParent();
156+
ModuleChanged = true;
157+
}
158+
159+
// In case spirv.CooperativeMatrixKHR is used in arrays, we also need to
160+
// insert GEP to get pointer to target exention type and use it instead of
161+
// pointer to sycl::joint_matrix class object when it is passed to
162+
// __spirv_AccessChain
163+
// First we check if the argument came from a GEP instruction
164+
GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(
165+
stripPointerCastsAndOffsets(CI->getArgOperand(0), /*StopOnGEP=*/true));
166+
if (!GEP)
74167
continue;
75-
StringRef Name = MatrixTy->getName();
76-
if (Name != MATRIX_TYPE)
168+
169+
// Check if GEP return type is a pointer to sycl::joint_matrix class object
170+
StructType *WrapperMatrixTy =
171+
dyn_cast<StructType>(GEP->getResultElementType());
172+
if (!extractMatrixType(WrapperMatrixTy))
77173
continue;
78174

79-
AllocaInst *Alloca = nullptr;
175+
// Insert GEP right before the __spirv_AccessChain call
80176
{
81177
IRBuilder Builder(CI);
82-
IRBuilderBase::InsertPointGuard IG(Builder);
83-
Builder.SetInsertPointPastAllocas(CI->getFunction());
84-
Alloca = Builder.CreateAlloca(MatrixTy);
178+
Value *NewGEP =
179+
Builder.CreateInBoundsGEP(WrapperMatrixTy, CI->getArgOperand(0),
180+
{Builder.getInt64(0), Builder.getInt32(0)});
181+
CI->setArgOperand(0, NewGEP);
182+
ModuleChanged = true;
85183
}
86-
Ptr->replaceAllUsesWith(Alloca);
87-
Ptr->dropAllReferences();
88-
Ptr->eraseFromParent();
89-
ModuleChanged = true;
90184
}
91185
return ModuleChanged;
92186
}

llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,69 @@
33

44
; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s
55

6-
; CHECK: %[[#Alloca:]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)
7-
; CHECK: %[[#Cast:]] = addrspacecast ptr %[[#Alloca]] to ptr addrspace(4)
8-
; CHECK: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef %[[#Cast]], i64 noundef 0)
9-
106
; ModuleID = 'test.bc'
117
source_filename = "test.cpp"
128
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
139
target triple = "spir64-unknown-unknown"
1410

15-
%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) }
11+
%"struct.sycl::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) }
12+
%"struct.sycl::_V1::long" = type { i64 }
13+
14+
define weak_odr dso_local spir_kernel void @test(i64 %ind) {
15+
; CHECK-LABEL: define weak_odr dso_local spir_kernel void @test(
16+
; CHECK-SAME: i64 [[IND:%.*]]) {
17+
18+
; non-matrix alloca not touched
19+
; CHECK: [[NOT_MATR:%.*]] = alloca [2 x [4 x %"struct.sycl::_V1::long"]]
20+
; both matrix-related allocas updated to use target extension types
21+
; CHECK-NEXT: [[MATR:%.*]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)
22+
; CHECK-NEXT: [[MATR_ARR:%.*]] = alloca [2 x [4 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]]
23+
24+
; CHECK-NEXT: [[ASCAST:%.*]] = addrspacecast ptr [[MATR]] to ptr addrspace(4)
25+
; no gep inserted, since not needed
26+
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[ASCAST]], i64 noundef 0)
27+
28+
; CHECK: [[GEP:%.*]] = getelementptr inbounds [2 x [4 x %"struct.sycl::joint_matrix"]], ptr [[MATR_ARR]], i64 0, i64 [[IND]], i64 [[IND]]
29+
; CHECK-NEXT: [[ASCAST_1:%.*]] = addrspacecast ptr [[GEP]] to ptr addrspace(4)
30+
; CHECK-NEXT: [[ASCAST_2:%.*]] = addrspacecast ptr [[GEP]] to ptr addrspace(4)
31+
; gep is inserted for each of the accesschain calls to extract target extension type
32+
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds %"struct.sycl::joint_matrix", ptr addrspace(4) [[ASCAST_1]], i64 0, i32 0
33+
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP2]], i64 noundef 0)
34+
; CHECK: [[TMP5:%.*]] = getelementptr inbounds %"struct.sycl::joint_matrix", ptr addrspace(4) [[ASCAST_2]], i64 0, i32 0
35+
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP5]], i64 noundef 0)
36+
37+
; negative test - not touching non-matrix code
38+
; CHECK: [[GEP_1:%.*]] = getelementptr inbounds [2 x [4 x %"struct.sycl::_V1::long"]], ptr [[NOT_MATR]], i64 0, i64 [[IND]], i64 [[IND]]
39+
; CHECK-NEXT: [[ASCAST_3:%.*]] = addrspacecast ptr [[GEP_1]] to ptr addrspace(4)
40+
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[ASCAST_3]], i64 noundef 0)
1641

17-
define weak_odr dso_local spir_kernel void @test() {
1842
entry:
19-
%0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8
20-
%1 = addrspacecast ptr %0 to ptr addrspace(4)
21-
%2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0)
43+
; allocas
44+
%matr = alloca %"struct.sycl::joint_matrix", align 8
45+
%matr.arr = alloca [2 x [4 x %"struct.sycl::joint_matrix"]], align 8
46+
%not.matr = alloca [2 x [4 x %"struct.sycl::_V1::long"]], align 8
47+
48+
; simple case
49+
%ascast = addrspacecast ptr %matr to ptr addrspace(4)
50+
%0 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast, i64 noundef 0)
51+
%1 = load i8, ptr addrspace(4) %0
52+
53+
; gep with non-zero inidices and multiple access chains per 1 alloca
54+
%gep = getelementptr inbounds [2 x [4 x %"struct.sycl::joint_matrix"]], ptr %matr.arr, i64 0, i64 %ind, i64 %ind
55+
%ascast.1 = addrspacecast ptr %gep to ptr addrspace(4)
56+
%ascast.2 = addrspacecast ptr %gep to ptr addrspace(4)
57+
%2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.1, i64 noundef 0)
2258
%3 = load i8, ptr addrspace(4) %2
59+
%4 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.2, i64 noundef 0)
60+
%5 = load i8, ptr addrspace(4) %4
61+
62+
; negative test - not touching non-matrix code
63+
%gep.1 = getelementptr inbounds [2 x [4 x %"struct.sycl::_V1::long"]], ptr %not.matr, i64 0, i64 %ind, i64 %ind
64+
%ascast.3 = addrspacecast ptr %gep.1 to ptr addrspace(4)
65+
%6 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.3, i64 noundef 0)
66+
%7 = load i8, ptr addrspace(4) %6
67+
2368
ret void
2469
}
2570

26-
declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef, i64 noundef)
71+
declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef, i64 noundef)
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# commit ea0f3a1f5f15f9af7bf40bd13669afeb9ada569c
2-
# Merge: bb64b3e9f6d3 4a89e1c69a65
3-
# Author: Martin Grant <martin.morrisongrant@codeplay.com>
4-
# Date: Thu Dec 19 11:26:01 2024 +0000
5-
# Merge pull request #2277 from igchor/cooperative_fix
6-
# [Spec] fix urKernelSuggestMaxCooperativeGroupCountExp
7-
set(UNIFIED_RUNTIME_TAG ea0f3a1f5f15f9af7bf40bd13669afeb9ada569c)
1+
# commit 232e62f5221d565ec40d051d3c640b836ca91244
2+
# Merge: 76a96238 59b37e3f
3+
# Author: aarongreig <aaron.greig@codeplay.com>
4+
# Date: Mon Dec 23 18:26:58 2024 +0000
5+
# Merge pull request #2498 from Bensuo/fabio/fix_l0_old_loader_no_translate
6+
# Update usage of zeCommandListImmediateAppendCommandListsExp to use dlsym
7+
set(UNIFIED_RUNTIME_TAG 232e62f5221d565ec40d051d3c640b836ca91244)

0 commit comments

Comments
 (0)