Skip to content

Commit 2d2acf4

Browse files
Convert the kernel state back to a reference when needed. (#715)
* Convert the kernel state back to a reference when needed. We're currently passing the kernel state object by value, disregarding the typical Julia calling convention, because there are known issues with `byval` lowering on NVPTX. For compatibility with back-ends that do not support passing kernel arguments by actual values, provide a pass that's conceptually the inverse of `lower_byval`, instead rewriting the kernel state object to be passed by reference, and loading from it at the beginning of the kernel. * `add_input_arguments!` for other backends (#718) Allows other backends to pass additional hidden arguments that can be accessed through intrinsics. Required for OpenCL device-side RNG support, where additional shared memory must be passed as arguments to the kernel. Co-authored-by: Simeon David Schaub <simeon@schaub.rocks>
1 parent 2effdee commit 2d2acf4

File tree

4 files changed

+322
-192
lines changed

4 files changed

+322
-192
lines changed

src/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ runtime_slug(@nospecialize(job::CompilerJob)) = error("Not implemented")
267267
kernel_state_type(@nospecialize(job::CompilerJob)) = Nothing
268268

269269
# Does the target need to pass kernel arguments by value?
270-
needs_byval(@nospecialize(job::CompilerJob)) = true
270+
pass_by_value(@nospecialize(job::CompilerJob)) = true
271271

272272
# whether pointer is a valid call target
273273
valid_function_pointer(@nospecialize(job::CompilerJob), ptr::Ptr{Cvoid}) = false

src/irgen.jl

Lines changed: 270 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ function irgen(@nospecialize(job::CompilerJob))
8282

8383
# minimal required optimization
8484
@tracepoint "rewrite" begin
85-
if job.config.kernel && needs_byval(job)
85+
if job.config.kernel && pass_by_value(job)
8686
# pass all bitstypes by value; by default Julia passes aggregates by reference
8787
# (this improves performance, and is mandated by certain back-ends like SPIR-V).
8888
args = classify_arguments(job, function_type(entry))
@@ -256,10 +256,11 @@ end
256256
## kernel promotion
257257

258258
@enum ArgumentCC begin
259-
BITS_VALUE # bitstype, passed as value
260-
BITS_REF # bitstype, passed as pointer
261-
MUT_REF # jl_value_t*, or the anonymous equivalent
262-
GHOST # not passed
259+
BITS_VALUE # bitstype, passed as value
260+
BITS_REF # bitstype, passed as pointer
261+
MUT_REF # jl_value_t*, or the anonymous equivalent
262+
GHOST # not passed
263+
KERNEL_STATE # the kernel state argument
263264
end
264265

265266
# Determine the calling convention of the arguments of a Julia function, given the
@@ -270,7 +271,8 @@ end
270271
# - `name`: the name of the argument
271272
# - `idx`: the index of the argument in the LLVM function type, or `nothing` if the argument
272273
# is not passed at the LLVM level.
273-
function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType)
274+
function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType;
275+
post_optimization::Bool=false)
274276
source_sig = job.source.specTypes
275277
source_types = [source_sig.parameters...]
276278

@@ -282,9 +284,15 @@ function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.Fu
282284

283285
codegen_types = parameters(codegen_ft)
284286

285-
args = []
286-
codegen_i = 1
287-
for (source_i, (source_typ, source_name)) in enumerate(zip(source_types, source_argnames))
287+
if post_optimization && kernel_state_type(job) !== Nothing
288+
args = []
289+
push!(args, (cc=KERNEL_STATE, typ=kernel_state_type(job), name=:kernel_state, idx=1))
290+
codegen_i = 2
291+
else
292+
args = []
293+
codegen_i = 1
294+
end
295+
for (source_typ, source_name) in zip(source_types, source_argnames)
288296
if isghosttype(source_typ) || Core.Compiler.isconstType(source_typ)
289297
push!(args, (cc=GHOST, typ=source_typ, name=source_name, idx=nothing))
290298
continue
@@ -817,3 +825,256 @@ function kernel_state_value(state)
817825
call_function(llvm_f, state)
818826
end
819827
end
828+
829+
# Rewrite the kernel state argument from pass-by-value to pass-by-reference.
#
# The kernel state argument is always passed by value to avoid codegen issues with byval.
# Some back-ends however do not support passing kernel arguments by value, so this pass
# converts that argument (conceptually the inverse of `lower_byval`): the first parameter
# of `f` becomes a pointer to the state, and the state is loaded from that pointer in a new
# entry block before falling through to the original body.
#
# Returns `f` unchanged when the job has no kernel state or when the first parameter of `f`
# is not of the kernel state type; otherwise returns the replacement function, which takes
# over the name of `f` (the old function is erased from `mod`).
function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                                    f::LLVM.Function)
    # nothing to do if this job does not use a kernel state at all
    state = kernel_state_type(job)
    state === Nothing && return f

    ft = function_type(f)
    T_state = convert(LLVMType, state)

    # the kernel state parameter should be the first argument; bail out if it is not
    old_types = parameters(ft)
    if isempty(old_types) || value_type(parameters(f)[1]) != T_state
        return f
    end

    @tracepoint "kernel state to reference" begin
        # new signature: a pointer to the state, followed by all other parameters as-is
        new_types = LLVM.LLVMType[LLVM.PointerType(T_state)]
        append!(new_types, old_types[2:end])

        new_ft = LLVM.FunctionType(return_type(ft), new_types)
        new_f = LLVM.Function(mod, "", new_ft)
        linkage!(new_f, linkage(f))

        # name the parameters
        new_params = parameters(new_f)
        LLVM.name!(new_params[1], "state_ptr")
        for (old_param, new_param) in zip(parameters(f)[2:end], new_params[2:end])
            LLVM.name!(new_param, LLVM.name(old_param))
        end

        # emit IR performing the "conversions"
        @dispose builder=IRBuilder() begin
            conversion = BasicBlock(new_f, "conversion")
            position!(builder, conversion)

            # the state value is loaded from the pointer; everything else passes through
            state_val = load!(builder, T_state, new_params[1], "state")
            new_args = LLVM.Value[state_val]
            for passthrough in new_params[2:end]
                push!(new_args, passthrough)
            end

            # map every old parameter (and the function itself) onto its replacement
            value_map = Dict{LLVM.Value, LLVM.Value}()
            for (old_param, replacement) in zip(parameters(f), new_args)
                value_map[old_param] = replacement
            end
            value_map[f] = new_f

            clone_into!(new_f, f; value_map,
                        changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)

            # fall through from the conversion block into the cloned body
            br!(builder, blocks(new_f)[2])
        end

        # attributes for the state pointer parameter:
        # - nocapture: the pointer cannot be captured since we immediately load from it
        # - noalias: each kernel state is separate
        # - readonly: the state is read-only
        attrs = parameter_attributes(new_f, 1)
        for attr_name in ("nocapture", "noalias", "readonly")
            push!(attrs, EnumAttribute(attr_name, 0))
        end

        # remove the old function, handing its name over to the replacement
        fn = LLVM.name(f)
        @assert isempty(uses(f))
        replace_metadata_uses!(f, new_f)
        erase!(f)
        LLVM.name!(new_f, fn)

        # minimal optimization
        @dispose pb=NewPMPassBuilder() begin
            add!(pb, SimplifyCFGPass())
            run!(pb, new_f, llvm_machine(job.config.target))
        end

        return new_f
    end
end
924+
925+
# Add hidden input arguments to a kernel for every kernel intrinsic that is used.
#
# `kernel_intrinsics` maps the name of an intrinsic function to a description providing a
# `typ` field (converted to the LLVM type of the new argument) and a `name` field (used to
# name the new argument). For each intrinsic that is present in `mod`, one argument is
# appended to `entry` and to every function that directly or transitively uses the
# intrinsic. Call sites of rewritten functions are updated to forward the new arguments,
# calls to the intrinsics are replaced by the matching argument, and the intrinsics are
# erased from the module.
#
# Returns the function in `mod` that carries the name of `entry` after the rewrite. When
# none of the intrinsics is present in `mod`, `entry` is returned untouched.
#
# Errors when a function that needs rewriting is used by anything other than a call
# instruction or a bitcast constant expression (invoke and callbr are not implemented).
function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                              entry::LLVM.Function, kernel_intrinsics::Dict)
    entry_fn = LLVM.name(entry)

    # figure out which intrinsics are used and need to be added as arguments
    used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn
        haskey(functions(mod), intr_fn)
    end |> collect
    nargs = length(used_intrinsics)

    # without used intrinsics there are no arguments to add, so avoid needlessly cloning
    # and replacing the entry point
    nargs == 0 && return entry

    # determine which functions need these arguments
    worklist = Set{LLVM.Function}([entry])
    for intr_fn in used_intrinsics
        push!(worklist, functions(mod)[intr_fn])
    end
    worklist_length = 0
    while worklist_length != length(worklist)
        # iteratively discover functions that use an intrinsic or any function calling it
        worklist_length = length(worklist)
        additions = LLVM.Function[]
        for f in worklist, use in uses(f)
            inst = user(use)::Instruction
            bb = LLVM.parent(inst)
            new_f = LLVM.parent(bb)
            in(new_f, worklist) || push!(additions, new_f)
        end
        for f in additions
            push!(worklist, f)
        end
    end
    # the intrinsics themselves only seeded the search; they are replaced, not extended
    for intr_fn in used_intrinsics
        delete!(worklist, functions(mod)[intr_fn])
    end

    # add the arguments
    # NOTE: we don't need to be fine-grained here, as unused args will be removed during opt
    workmap = Dict{LLVM.Function, LLVM.Function}()
    for f in worklist
        fn = LLVM.name(f)
        ft = function_type(f)
        # free up the name so that the replacement can take it
        LLVM.name!(f, fn * ".orig")

        # create a new function, with one extra trailing parameter per used intrinsic
        new_param_types = LLVMType[parameters(ft)...]
        for intr_fn in used_intrinsics
            llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ)
            push!(new_param_types, llvm_typ)
        end
        new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
        new_f = LLVM.Function(mod, fn, new_ft)
        linkage!(new_f, linkage(f))

        # name the existing parameters after the originals, and the new ones after the
        # intrinsics they replace
        for (arg, new_arg) in zip(parameters(f), parameters(new_f))
            LLVM.name!(new_arg, LLVM.name(arg))
        end
        for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
            LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name)
        end

        workmap[f] = new_f
    end

    # clone and rewrite the function bodies.
    # we don't need to rewrite much as the arguments are added last.
    for (f, new_f) in workmap
        # map the arguments (the parameters have already been named above)
        value_map = Dict{LLVM.Value, LLVM.Value}()
        for (param, new_param) in zip(parameters(f), parameters(new_f))
            value_map[param] = new_param
        end

        value_map[f] = new_f
        clone_into!(new_f, f; value_map,
                    changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)

        # we can't remove this function yet, as we might still need to rewrite any called,
        # but remove the IR already
        empty!(f)
    end

    # drop unused constants that may be referring to the old functions
    # XXX: can we do this differently?
    for f in worklist
        prune_constexpr_uses!(f)
    end

    # update other uses of the old function, modifying call sites to pass the arguments
    function rewrite_uses!(f, new_f)
        @dispose builder=IRBuilder() begin
            for use in uses(f)
                val = user(use)
                if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                    # the function containing this call; it carries the extra arguments
                    # as its trailing parameters
                    caller_f = LLVM.parent(LLVM.parent(val))
                    # forward the arguments
                    position!(builder, val)
                    new_val = if val isa LLVM.CallInst
                        call!(builder, function_type(new_f), new_f,
                              [arguments(val)..., parameters(caller_f)[end-nargs+1:end]...],
                              operand_bundles(val))
                    else
                        # TODO: invoke and callbr
                        error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
                    end
                    callconv!(new_val, callconv(val))

                    replace_uses!(val, new_val)
                    @assert isempty(uses(val))
                    erase!(val)
                elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
                    # XXX: why isn't this caught by the value materializer above?
                    target = operands(val)[1]
                    @assert target == f
                    # we can't simply replace this constant expression, as it may be used
                    # as a call, taking arguments (so we need to rewrite it to pass the
                    # input arguments)
                    new_val = LLVM.const_bitcast(new_f, value_type(val))
                    rewrite_uses!(val, new_val)

                    # drop the old constant if it is unused
                    # XXX: can we do this differently?
                    if isempty(uses(val))
                        LLVM.unsafe_destroy!(val)
                    end
                else
                    error("Cannot rewrite unknown use of function: $val")
                end
            end
        end
    end
    for (f, new_f) in workmap
        rewrite_uses!(f, new_f)
        @assert isempty(uses(f))
        erase!(f)
    end

    # replace uses of the intrinsics with references to the input arguments
    for (i, intr_fn) in enumerate(used_intrinsics)
        intr = functions(mod)[intr_fn]
        for use in uses(intr)
            val = user(use)
            if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                # look up the containing function only once we know this is an
                # instruction, so that any other kind of user reaches the error below
                caller_f = LLVM.parent(LLVM.parent(val))
                replace_uses!(val, parameters(caller_f)[end-nargs+i])
            else
                error("Cannot rewrite unknown use of function: $val")
            end

            @assert isempty(uses(val))
            erase!(val)
        end
        @assert isempty(uses(intr))
        erase!(intr)
    end

    return functions(mod)[entry_fn]
end

0 commit comments

Comments
 (0)