Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "GPUCompiler"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "1.7.0"
version = "1.7.1"
authors = ["Tim Besard <tim.besard@gmail.com>"]

[deps]
Expand Down
104 changes: 54 additions & 50 deletions src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,14 @@ const __llvm_initialized = Ref(false)
erase!(call)
end
end

# minimal optimization to convert the inttoptr/call into a direct call
@dispose pb=NewPMPassBuilder() begin
add!(pb, NewPMFunctionPassManager()) do fpm
add!(fpm, InstCombinePass())
end
run!(pb, ir, llvm_machine(job.config.target))
end
end

# all deferred compilations should have been resolved
Expand Down Expand Up @@ -312,10 +320,15 @@ const __llvm_initialized = Ref(false)
end

@tracepoint "IR post-processing" begin
# mark everything internal except for entrypoints and any exported
# global variables. this makes sure that the optimizer can, e.g.,
# rewrite function signatures.
# mark the kernel entry-point functions (optimization may need it)
if job.config.kernel
mark_kernel!(entry)
end

if job.config.toplevel
# mark everything internal except for entrypoints and any exported
# global variables. this makes sure that the optimizer can, e.g.,
# rewrite function signatures.
preserved_gvs = collect(values(jobs))
for gvar in globals(ir)
if linkage(gvar) == LLVM.API.LLVMExternalLinkage
Expand All @@ -331,64 +344,55 @@ const __llvm_initialized = Ref(false)
run!(pm, ir)
end
end
end

# mark the kernel entry-point functions (optimization may need it)
if job.config.kernel
push!(metadata(ir)["julia.kernel"], MDNode([entry]))

# IDEA: save all jobs, not only kernels, and save other attributes
# so that we can reconstruct the CompileJob instead of setting it globally
end

if job.config.toplevel && job.config.optimize
@tracepoint "optimization" begin
optimize!(job, ir; job.config.opt_level)
finish_linked_module!(job, ir)

if job.config.optimize
@tracepoint "optimization" begin
optimize!(job, ir; job.config.opt_level)

# deferred codegen has some special optimization requirements,
# which also need to happen _after_ regular optimization.
# XXX: make these part of the optimizer pipeline?
if has_deferred_jobs
@dispose pb=NewPMPassBuilder() begin
add!(pb, NewPMFunctionPassManager()) do fpm
add!(fpm, InstCombinePass())
end
add!(pb, AlwaysInlinerPass())
add!(pb, NewPMFunctionPassManager()) do fpm
add!(fpm, SROAPass())
add!(fpm, GVNPass())
end
add!(pb, MergeFunctionsPass())
run!(pb, ir, llvm_machine(job.config.target))
end
end
end
end

# deferred codegen has some special optimization requirements,
# which also need to happen _after_ regular optimization.
# XXX: make these part of the optimizer pipeline?
if has_deferred_jobs
if job.config.cleanup
@tracepoint "clean-up" begin
@dispose pb=NewPMPassBuilder() begin
add!(pb, NewPMFunctionPassManager()) do fpm
add!(fpm, InstCombinePass())
end
add!(pb, AlwaysInlinerPass())
add!(pb, NewPMFunctionPassManager()) do fpm
add!(fpm, SROAPass())
add!(fpm, GVNPass())
end
add!(pb, MergeFunctionsPass())
add!(pb, RecomputeGlobalsAAPass())
add!(pb, GlobalOptPass())
add!(pb, GlobalDCEPass())
add!(pb, StripDeadPrototypesPass())
add!(pb, ConstantMergePass())
run!(pb, ir, llvm_machine(job.config.target))
end
end
end

# optimization may have replaced functions, so look the entry point up again
entry = functions(ir)[entry_fn]
end

if job.config.toplevel && job.config.cleanup
@tracepoint "clean-up" begin
@dispose pb=NewPMPassBuilder() begin
add!(pb, RecomputeGlobalsAAPass())
add!(pb, GlobalOptPass())
add!(pb, GlobalDCEPass())
add!(pb, StripDeadPrototypesPass())
add!(pb, ConstantMergePass())
run!(pb, ir, llvm_machine(job.config.target))
end
end
end

# finish the module
#
# we want to finish the module after optimization, so we cannot do so
# during deferred code generation. instead, process the deferred jobs
# here.
if job.config.toplevel
# finish the module
#
# we want to finish the module after optimization, so we cannot do so
# during deferred code generation. instead, process the deferred jobs
# here.
entry = finish_ir!(job, ir, entry)

for (job′, fn′) in jobs
job′ == job && continue
finish_ir!(job′, ir, functions(ir)[fn′])
Expand All @@ -409,7 +413,7 @@ const __llvm_initialized = Ref(false)
end

if job.config.toplevel && job.config.validate
@tracepoint "Validation" begin
@tracepoint "validation" begin
check_ir(job, ir)
end
end
Expand Down
3 changes: 3 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ link_libraries!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
# default method: does not touch `mod` and hands the `entry` function back unchanged.
finish_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module, entry::LLVM.Function) =
entry

# finalization of linked modules, after deferred codegen but before optimization.
# this default method does nothing and returns `nothing`; a back-end can add a method
# for its own `CompilerJob` type to process the whole linked module at once.
finish_linked_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module) = return

# post-Julia optimization processing of the module.
# this default method does nothing and returns `nothing`.
optimize_module!(@nospecialize(job::CompilerJob), mod::LLVM.Module) = return

Expand Down
33 changes: 20 additions & 13 deletions src/irgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -526,18 +526,12 @@ function add_kernel_state!(mod::LLVM.Module)
state_intr = kernel_state_intr(mod, T_state)
state_intr_ft = LLVM.FunctionType(T_state)

kernels = []
kernels_md = metadata(mod)["julia.kernel"]
for kernel_md in operands(kernels_md)
push!(kernels, Value(operands(kernel_md)[1]))
end

# determine which functions need a kernel state argument
#
# previously, we added the argument to every function and relied on unused arg elim to
# clean-up the IR. however, some libraries do Funny Stuff, e.g., libdevice bitcasting
# function pointers. such IR is hard to rewrite, so instead be more conservative.
worklist = Set{LLVM.Function}([state_intr, kernels...])
worklist = Set{LLVM.Function}([state_intr, kernels(mod)...])
worklist_length = 0
while worklist_length != length(worklist)
# iteratively discover functions that use the intrinsic or any function calling it
Expand Down Expand Up @@ -941,12 +935,24 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
while worklist_length != length(worklist)
# iteratively discover functions that use an intrinsic or any function calling it
worklist_length = length(worklist)
additions = LLVM.Function[]
for f in worklist, use in uses(f)
inst = user(use)::Instruction
bb = LLVM.parent(inst)
new_f = LLVM.parent(bb)
in(new_f, worklist) || push!(additions, new_f)
additions = Set{LLVM.Function}()
# walk the uses of `val` and record every function containing an instruction that
# (directly, or through constant expressions) uses it. newly discovered functions are
# collected in `additions`; functions already in `worklist` are not added again.
function scan_uses(val)
for use in uses(val)
candidate = user(use)
if isa(candidate, Instruction)
# an instruction lives in a basic block, which in turn lives in a function
bb = LLVM.parent(candidate)
new_f = LLVM.parent(bb)
in(new_f, worklist) || push!(additions, new_f)
elseif isa(candidate, ConstantExpr)
# the user is a constant expression: look through it by scanning its own uses
# NOTE(review): this log statement looks like a debugging leftover -- confirm that
# emitting it for every constant expression is intended
@safe_info candidate
scan_uses(candidate)
else
# any other kind of user is not handled here, so fail loudly
error("Don't know how to check uses of $candidate. Please file an issue.")
end
end
end
for f in worklist
scan_uses(f)
end
for f in additions
push!(worklist, f)
Expand Down Expand Up @@ -1054,6 +1060,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
for (f, new_f) in workmap
rewrite_uses!(f, new_f)
@assert isempty(uses(f))
replace_metadata_uses!(f, new_f)
erase!(f)
end

Expand Down
15 changes: 7 additions & 8 deletions src/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,11 @@ runtime_slug(job::CompilerJob{MetalCompilerTarget}) = "metal-macos$(job.config.t
# for Metal jobs, a function counts as an intrinsic when its name starts with `air.`
function isintrinsic(@nospecialize(job::CompilerJob{MetalCompilerTarget}), fn::String)
    return startswith(fn, "air.")
end

function finish_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::LLVM.Module, entry::LLVM.Function)
entry_fn = LLVM.name(entry)

# update calling conventions
if job.config.kernel
entry = pass_by_reference!(job, mod, entry)
entry = add_input_arguments!(job, mod, entry, kernel_intrinsics)
function finish_linked_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::LLVM.Module)
for f in kernels(mod)
# update calling conventions
f = pass_by_reference!(job, mod, f)
f = add_input_arguments!(job, mod, f, kernel_intrinsics)
end

# emit the AIR and Metal version numbers as constants in the module. this makes it
Expand Down Expand Up @@ -83,7 +81,7 @@ function finish_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mo
run!(pb, mod)
end

return functions(mod)[entry_fn]
return
end

function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module)
Expand Down Expand Up @@ -497,6 +495,7 @@ function pass_by_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f
# NOTE: if we ever have legitimate uses of the old function, create a shim instead
fn = LLVM.name(f)
@assert isempty(uses(f))
replace_metadata_uses!(f, new_f)
erase!(f)
LLVM.name!(new_f, fn)

Expand Down
27 changes: 27 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,30 @@ function prune_constexpr_uses!(root::LLVM.Value)
end
end
end


## kernel metadata handling

# kernels are encoded in the IR using the julia.kernel metadata.

# IDEA: don't only mark kernels, but all jobs, and save all attributes of the CompilerJob
# so that we can reconstruct the CompilerJob instead of setting it globally

# register `f` as a kernel by appending it to the `julia.kernel` named metadata of the
# module that contains it. returns `f`.
function mark_kernel!(f::LLVM.Function)
    kernel_md = metadata(LLVM.parent(f))["julia.kernel"]
    push!(kernel_md, MDNode([f]))
    return f
end

# collect all functions that are listed in the `julia.kernel` metadata of the module.
# the result is a freshly allocated vector, which is empty when that metadata is absent.
function kernels(mod::LLVM.Module)
    found = LLVM.Function[]
    mod_md = metadata(mod)
    haskey(mod_md, "julia.kernel") || return found
    for node in operands(mod_md["julia.kernel"])
        # the kernel function is stored as the first operand of each node
        push!(found, LLVM.Value(operands(node)[1]))
    end
    return found
end