Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions src/irgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -921,3 +921,160 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M
return new_f
end
end

function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
entry::LLVM.Function, kernel_intrinsics::Dict)
entry_fn = LLVM.name(entry)

# figure out which intrinsics are used and need to be added as arguments
used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn
haskey(functions(mod), intr_fn)
end |> collect
nargs = length(used_intrinsics)

# determine which functions need these arguments
worklist = Set{LLVM.Function}([entry])
for intr_fn in used_intrinsics
push!(worklist, functions(mod)[intr_fn])
end
worklist_length = 0
while worklist_length != length(worklist)
# iteratively discover functions that use an intrinsic or any function calling it
worklist_length = length(worklist)
additions = LLVM.Function[]
for f in worklist, use in uses(f)
inst = user(use)::Instruction
bb = LLVM.parent(inst)
new_f = LLVM.parent(bb)
in(new_f, worklist) || push!(additions, new_f)
end
for f in additions
push!(worklist, f)
end
end
for intr_fn in used_intrinsics
delete!(worklist, functions(mod)[intr_fn])
end

# add the arguments
# NOTE: we don't need to be fine-grained here, as unused args will be removed during opt
workmap = Dict{LLVM.Function, LLVM.Function}()
for f in worklist
fn = LLVM.name(f)
ft = function_type(f)
LLVM.name!(f, fn * ".orig")
# create a new function
new_param_types = LLVMType[parameters(ft)...]

for intr_fn in used_intrinsics
llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ)
push!(new_param_types, llvm_typ)
end
new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
new_f = LLVM.Function(mod, fn, new_ft)
linkage!(new_f, linkage(f))
for (arg, new_arg) in zip(parameters(f), parameters(new_f))
LLVM.name!(new_arg, LLVM.name(arg))
end
for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name)
end

workmap[f] = new_f
end

# clone and rewrite the function bodies.
# we don't need to rewrite much as the arguments are added last.
for (f, new_f) in workmap
# map the arguments
value_map = Dict{LLVM.Value, LLVM.Value}()
for (param, new_param) in zip(parameters(f), parameters(new_f))
LLVM.name!(new_param, LLVM.name(param))
value_map[param] = new_param
end

value_map[f] = new_f
clone_into!(new_f, f; value_map,
changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)

# we can't remove this function yet, as we might still need to rewrite any called,
# but remove the IR already
empty!(f)
end

# drop unused constants that may be referring to the old functions
# XXX: can we do this differently?
for f in worklist
prune_constexpr_uses!(f)
end

# update other uses of the old function, modifying call sites to pass the arguments
function rewrite_uses!(f, new_f)
# update uses
@dispose builder=IRBuilder() begin
for use in uses(f)
val = user(use)
if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
callee_f = LLVM.parent(LLVM.parent(val))
# forward the arguments
position!(builder, val)
new_val = if val isa LLVM.CallInst
call!(builder, function_type(new_f), new_f,
[arguments(val)..., parameters(callee_f)[end-nargs+1:end]...],
operand_bundles(val))
else
# TODO: invoke and callbr
error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
end
callconv!(new_val, callconv(val))

replace_uses!(val, new_val)
@assert isempty(uses(val))
erase!(val)
elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
# XXX: why isn't this caught by the value materializer above?
target = operands(val)[1]
@assert target == f
new_val = LLVM.const_bitcast(new_f, value_type(val))
rewrite_uses!(val, new_val)
# we can't simply replace this constant expression, as it may be used
# as a call, taking arguments (so we need to rewrite it to pass the input arguments)

# drop the old constant if it is unused
# XXX: can we do this differently?
if isempty(uses(val))
LLVM.unsafe_destroy!(val)
end
else
error("Cannot rewrite unknown use of function: $val")
end
end
end
end
for (f, new_f) in workmap
rewrite_uses!(f, new_f)
@assert isempty(uses(f))
erase!(f)
end

# replace uses of the intrinsics with references to the input arguments
for (i, intr_fn) in enumerate(used_intrinsics)
intr = functions(mod)[intr_fn]
for use in uses(intr)
val = user(use)
callee_f = LLVM.parent(LLVM.parent(val))
if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
replace_uses!(val, parameters(callee_f)[end-nargs+i])
else
error("Cannot rewrite unknown use of function: $val")
end

@assert isempty(uses(val))
erase!(val)
end
@assert isempty(uses(intr))
erase!(intr)
end

return functions(mod)[entry_fn]
end
162 changes: 1 addition & 161 deletions src/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ function finish_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mo
# update calling conventions
if job.config.kernel
entry = pass_by_reference!(job, mod, entry)

add_input_arguments!(job, mod, entry)
entry = LLVM.functions(mod)[entry_fn]
entry = add_input_arguments!(job, mod, entry, kernel_intrinsics)
end

# emit the AIR and Metal version numbers as constants in the module. this makes it
Expand Down Expand Up @@ -553,164 +551,6 @@ function argument_type_name(typ)
end
end

function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
entry::LLVM.Function)
entry_fn = LLVM.name(entry)

# figure out which intrinsics are used and need to be added as arguments
used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn
haskey(functions(mod), intr_fn)
end |> collect
nargs = length(used_intrinsics)

# determine which functions need these arguments
worklist = Set{LLVM.Function}([entry])
for intr_fn in used_intrinsics
push!(worklist, functions(mod)[intr_fn])
end
worklist_length = 0
while worklist_length != length(worklist)
# iteratively discover functions that use an intrinsic or any function calling it
worklist_length = length(worklist)
additions = LLVM.Function[]
for f in worklist, use in uses(f)
inst = user(use)::Instruction
bb = LLVM.parent(inst)
new_f = LLVM.parent(bb)
in(new_f, worklist) || push!(additions, new_f)
end
for f in additions
push!(worklist, f)
end
end
for intr_fn in used_intrinsics
delete!(worklist, functions(mod)[intr_fn])
end

# add the arguments
# NOTE: we don't need to be fine-grained here, as unused args will be removed during opt
workmap = Dict{LLVM.Function, LLVM.Function}()
for f in worklist
fn = LLVM.name(f)
ft = function_type(f)
LLVM.name!(f, fn * ".orig")
# create a new function
new_param_types = LLVMType[parameters(ft)...]

for intr_fn in used_intrinsics
llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ)
push!(new_param_types, llvm_typ)
end
new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
new_f = LLVM.Function(mod, fn, new_ft)
linkage!(new_f, linkage(f))
for (arg, new_arg) in zip(parameters(f), parameters(new_f))
LLVM.name!(new_arg, LLVM.name(arg))
end
for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name)
end

workmap[f] = new_f
end

# clone and rewrite the function bodies.
# we don't need to rewrite much as the arguments are added last.
for (f, new_f) in workmap
# map the arguments
value_map = Dict{LLVM.Value, LLVM.Value}()
for (param, new_param) in zip(parameters(f), parameters(new_f))
LLVM.name!(new_param, LLVM.name(param))
value_map[param] = new_param
end

value_map[f] = new_f
clone_into!(new_f, f; value_map,
changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)

# we can't remove this function yet, as we might still need to rewrite any called,
# but remove the IR already
empty!(f)
end

# drop unused constants that may be referring to the old functions
# XXX: can we do this differently?
for f in worklist
prune_constexpr_uses!(f)
end

# update other uses of the old function, modifying call sites to pass the arguments
function rewrite_uses!(f, new_f)
# update uses
@dispose builder=IRBuilder() begin
for use in uses(f)
val = user(use)
if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
callee_f = LLVM.parent(LLVM.parent(val))
# forward the arguments
position!(builder, val)
new_val = if val isa LLVM.CallInst
call!(builder, function_type(new_f), new_f,
[arguments(val)..., parameters(callee_f)[end-nargs+1:end]...],
operand_bundles(val))
else
# TODO: invoke and callbr
error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
end
callconv!(new_val, callconv(val))

replace_uses!(val, new_val)
@assert isempty(uses(val))
erase!(val)
elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
# XXX: why isn't this caught by the value materializer above?
target = operands(val)[1]
@assert target == f
new_val = LLVM.const_bitcast(new_f, value_type(val))
rewrite_uses!(val, new_val)
# we can't simply replace this constant expression, as it may be used
# as a call, taking arguments (so we need to rewrite it to pass the input arguments)

# drop the old constant if it is unused
# XXX: can we do this differently?
if isempty(uses(val))
LLVM.unsafe_destroy!(val)
end
else
error("Cannot rewrite unknown use of function: $val")
end
end
end
end
for (f, new_f) in workmap
rewrite_uses!(f, new_f)
@assert isempty(uses(f))
erase!(f)
end

# replace uses of the intrinsics with references to the input arguments
for (i, intr_fn) in enumerate(used_intrinsics)
intr = functions(mod)[intr_fn]
for use in uses(intr)
val = user(use)
callee_f = LLVM.parent(LLVM.parent(val))
if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
replace_uses!(val, parameters(callee_f)[end-nargs+i])
else
error("Cannot rewrite unknown use of function: $val")
end

@assert isempty(uses(val))
erase!(val)
end
@assert isempty(uses(intr))
erase!(intr)
end

return
end


# argument metadata generation
#
# module metadata is used to identify buffers that are passed as kernel arguments.
Expand Down
Loading