@@ -82,7 +82,7 @@ function irgen(@nospecialize(job::CompilerJob))
8282
8383    #  minimal required optimization
8484    @tracepoint  " rewrite" begin 
85-         if  job. config. kernel &&  needs_byval (job)
85+         if  job. config. kernel &&  pass_by_value (job)
8686            #  pass all bitstypes by value; by default Julia passes aggregates by reference
8787            #  (this improves performance, and is mandated by certain back-ends like SPIR-V).
8888            args =  classify_arguments (job, function_type (entry))
@@ -256,10 +256,11 @@ end
256256# # kernel promotion
257257
# Calling convention of a single kernel argument at the LLVM level.
@enum ArgumentCC begin
    BITS_VALUE      # bitstype, passed as value
    BITS_REF        # bitstype, passed as pointer
    MUT_REF         # jl_value_t*, or the anonymous equivalent
    GHOST           # not passed
    KERNEL_STATE    # the kernel state argument
end
264265
265266#  Determine the calling convention of the arguments of a Julia function, given the
270271#  - `name`: the name of the argument
271272#  - `idx`: the index of the argument in the LLVM function type, or `nothing` if the argument
272273#           is not passed at the LLVM level.
273- function  classify_arguments (@nospecialize (job:: CompilerJob ), codegen_ft:: LLVM.FunctionType )
274+ function  classify_arguments (@nospecialize (job:: CompilerJob ), codegen_ft:: LLVM.FunctionType ;
275+                             post_optimization:: Bool = false )
274276    source_sig =  job. source. specTypes
275277    source_types =  [source_sig. parameters... ]
276278
@@ -282,9 +284,15 @@ function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.Fu
282284
283285    codegen_types =  parameters (codegen_ft)
284286
285-     args =  []
286-     codegen_i =  1 
287-     for  (source_i, (source_typ, source_name)) in  enumerate (zip (source_types, source_argnames))
287+     if  post_optimization &&  kernel_state_type (job) != =  Nothing
288+         args =  []
289+         push! (args, (cc= KERNEL_STATE, typ= kernel_state_type (job), name= :kernel_state , idx= 1 ))
290+         codegen_i =  2 
291+     else 
292+         args =  []
293+         codegen_i =  1 
294+     end 
295+     for  (source_typ, source_name) in  zip (source_types, source_argnames)
288296        if  isghosttype (source_typ) ||  Core. Compiler. isconstType (source_typ)
289297            push! (args, (cc= GHOST, typ= source_typ, name= source_name, idx= nothing ))
290298            continue 
@@ -817,3 +825,256 @@ function kernel_state_value(state)
817825        call_function (llvm_f, state)
818826    end 
819827end 
828+ 
# convert kernel state argument from pass-by-value to pass-by-reference
#
# the kernel state argument is always passed by value to avoid codegen issues with byval.
# some back-ends however do not support passing kernel arguments by value, so this pass
# serves to convert that argument (and is conceptually the inverse of `lower_byval`).
#
# Arguments:
# - `job`: the compiler job; used to query `kernel_state_type(job)` and the target machine
# - `mod`: the module in which the replacement function is created
# - `f`:   the function whose first parameter is expected to be the kernel state
#
# Returns `f` unchanged when the job has no kernel state type (`Nothing`), when `f` takes
# no parameters, or when the type of its first parameter is not the LLVM type of the kernel
# state. Otherwise returns a new function that takes over the name and linkage of `f`, and
# whose first parameter is a pointer to the state; `f` itself is erased from the module.
function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                                    f::LLVM.Function)
    ft = function_type(f)

    # check if we even need a kernel state argument
    state = kernel_state_type(job)
    if state === Nothing
        return f
    end

    T_state = convert(LLVMType, state)

    # find the kernel state parameter (should be the first argument)
    if isempty(parameters(ft)) || value_type(parameters(f)[1]) != T_state
        return f
    end

    @tracepoint "kernel state to reference" begin
        # generate the new function type & definition:
        # the first parameter (kernel state) becomes a pointer, all others are kept as-is
        new_types = LLVM.LLVMType[LLVM.PointerType(T_state)]
        append!(new_types, parameters(ft)[2:end])

        new_ft = LLVM.FunctionType(return_type(ft), new_types)
        # created anonymously; it takes over the name of `f` once that has been erased
        new_f = LLVM.Function(mod, "", new_ft)
        linkage!(new_f, linkage(f))

        # name the parameters
        LLVM.name!(parameters(new_f)[1], "state_ptr")
        for (arg, new_arg) in zip(parameters(f)[2:end], parameters(new_f)[2:end])
            LLVM.name!(new_arg, LLVM.name(arg))
        end

        # emit IR performing the "conversions"
        new_args = LLVM.Value[]
        @dispose builder=IRBuilder() begin
            entry = BasicBlock(new_f, "conversion")
            position!(builder, entry)

            # load the kernel state value from the pointer
            state_val = load!(builder, T_state, parameters(new_f)[1], "state")
            push!(new_args, state_val)

            # all other arguments are passed through directly
            append!(new_args, parameters(new_f)[2:end])

            # map the old parameters onto the values that replace them
            value_map = Dict{LLVM.Value, LLVM.Value}(
                param => new_args[i] for (i, param) in enumerate(parameters(f))
            )
            value_map[f] = new_f

            clone_into!(new_f, f; value_map,
                        changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)

            # fall through: the cloned body follows the "conversion" block, so branch
            # from the conversion block to the second block of the new function
            br!(builder, blocks(new_f)[2])
        end

        # set the attributes for the state pointer parameter
        attrs = parameter_attributes(new_f, 1)
        # the pointer itself cannot be captured since we immediately load from it
        push!(attrs, EnumAttribute("nocapture", 0))
        # each kernel state is separate
        push!(attrs, EnumAttribute("noalias", 0))
        # the state is read-only
        push!(attrs, EnumAttribute("readonly", 0))

        # remove the old function, and hand its name over to the replacement
        fn = LLVM.name(f)
        @assert isempty(uses(f))
        replace_metadata_uses!(f, new_f)
        erase!(f)
        LLVM.name!(new_f, fn)

        # minimal optimization
        @dispose pb=NewPMPassBuilder() begin
            add!(pb, SimplifyCFGPass())
            run!(pb, new_f, llvm_machine(job.config.target))
        end

        return new_f
    end
end
924+ 
# Turn calls to kernel intrinsics into extra function arguments.
#
# Every intrinsic listed in `kernel_intrinsics` that is present in `mod` becomes an
# additional trailing parameter of `entry` and of every function that (transitively) calls
# such an intrinsic. Call sites of the rewritten functions forward the caller's trailing
# parameters, and calls to the intrinsics themselves are replaced by the matching
# parameter, after which the intrinsic declarations are erased.
#
# Arguments:
# - `job`: the compiler job (not used by the rewrite itself)
# - `mod`: the module to rewrite in place
# - `entry`: the entry function; it is always rewritten
# - `kernel_intrinsics`: maps an intrinsic's function name to a description providing a
#   `typ` field (converted to the LLVM type of the new parameter) and a `name` field (the
#   name given to the new parameter)
#
# Returns the function that is registered in `mod` under the original name of `entry`.
#
# Throws an error when a rewritten function is used by an `invoke` or `callbr`, or by
# anything other than a call or a bitcast constant expression.
function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                              entry::LLVM.Function, kernel_intrinsics::Dict)
    entry_fn = LLVM.name(entry)

    # figure out which intrinsics are used and need to be added as arguments
    used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn
        haskey(functions(mod), intr_fn)
    end |> collect
    nargs = length(used_intrinsics)

    # determine which functions need these arguments
    worklist = Set{LLVM.Function}([entry])
    for intr_fn in used_intrinsics
        push!(worklist, functions(mod)[intr_fn])
    end
    worklist_length = 0
    while worklist_length != length(worklist)
        # iteratively discover functions that use an intrinsic or any function calling it
        worklist_length = length(worklist)
        additions = LLVM.Function[]
        for f in worklist, use in uses(f)
            inst = user(use)::Instruction
            bb = LLVM.parent(inst)
            new_f = LLVM.parent(bb)
            in(new_f, worklist) || push!(additions, new_f)
        end
        union!(worklist, additions)
    end
    # the intrinsics themselves only seeded the search; they are not rewritten
    for intr_fn in used_intrinsics
        delete!(worklist, functions(mod)[intr_fn])
    end

    # add the arguments
    # NOTE: we don't need to be fine-grained here, as unused args will be removed during opt
    workmap = Dict{LLVM.Function, LLVM.Function}()
    for f in worklist
        fn = LLVM.name(f)
        ft = function_type(f)
        # free up the name so that the replacement can take it
        LLVM.name!(f, fn * ".orig")

        # create a new function, with one extra trailing parameter per used intrinsic
        new_param_types = LLVMType[parameters(ft)...]
        for intr_fn in used_intrinsics
            llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ)
            push!(new_param_types, llvm_typ)
        end
        new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
        new_f = LLVM.Function(mod, fn, new_ft)
        linkage!(new_f, linkage(f))
        for (arg, new_arg) in zip(parameters(f), parameters(new_f))
            LLVM.name!(new_arg, LLVM.name(arg))
        end
        for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
            LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name)
        end

        workmap[f] = new_f
    end

    # clone and rewrite the function bodies.
    # we don't need to rewrite much as the arguments are added last.
    for (f, new_f) in workmap
        # map the arguments (the new parameters were already named when created above)
        value_map = Dict{LLVM.Value, LLVM.Value}()
        for (param, new_param) in zip(parameters(f), parameters(new_f))
            value_map[param] = new_param
        end

        value_map[f] = new_f
        clone_into!(new_f, f; value_map,
                    changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)

        # we can't remove this function yet, as we might still need to rewrite any called,
        # but remove the IR already
        empty!(f)
    end

    # drop unused constants that may be referring to the old functions
    # XXX: can we do this differently?
    for f in worklist
        prune_constexpr_uses!(f)
    end

    # update other uses of the old function, modifying call sites to pass the arguments
    function rewrite_uses!(f, new_f)
        # update uses
        @dispose builder=IRBuilder() begin
            for use in uses(f)
                val = user(use)
                if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                    # the function containing the call; its trailing parameters are the
                    # input arguments that have to be forwarded
                    caller_f = LLVM.parent(LLVM.parent(val))
                    # forward the arguments
                    position!(builder, val)
                    new_val = if val isa LLVM.CallInst
                        # explicitly typed, so that an empty argument list does not
                        # degrade to a `Vector{Any}`
                        call!(builder, function_type(new_f), new_f,
                              LLVM.Value[arguments(val)...,
                                         parameters(caller_f)[end-nargs+1:end]...],
                              operand_bundles(val))
                    else
                        # TODO: invoke and callbr
                        error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
                    end
                    callconv!(new_val, callconv(val))

                    replace_uses!(val, new_val)
                    @assert isempty(uses(val))
                    erase!(val)
                elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
                    # XXX: why isn't this caught by the value materializer above?
                    target = operands(val)[1]
                    @assert target == f
                    new_val = LLVM.const_bitcast(new_f, value_type(val))
                    rewrite_uses!(val, new_val)
                    # we can't simply replace this constant expression, as it may be used
                    # as a call, taking arguments (so we need to rewrite it to pass the
                    # input arguments)

                    # drop the old constant if it is unused
                    # XXX: can we do this differently?
                    if isempty(uses(val))
                        LLVM.unsafe_destroy!(val)
                    end
                else
                    error("Cannot rewrite unknown use of function: $val")
                end
            end
        end
    end
    for (f, new_f) in workmap
        rewrite_uses!(f, new_f)
        @assert isempty(uses(f))
        erase!(f)
    end

    # replace uses of the intrinsics with references to the input arguments
    for (i, intr_fn) in enumerate(used_intrinsics)
        intr = functions(mod)[intr_fn]
        for use in uses(intr)
            val = user(use)
            # validate the kind of user first: only instructions have a parent block, so
            # looking up the enclosing function of anything else would fail before the
            # descriptive error below could be raised
            if !(val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst)
                error("Cannot rewrite unknown use of function: $val")
            end
            caller_f = LLVM.parent(LLVM.parent(val))
            replace_uses!(val, parameters(caller_f)[end-nargs+i])

            @assert isempty(uses(val))
            erase!(val)
        end
        @assert isempty(uses(intr))
        erase!(intr)
    end

    return functions(mod)[entry_fn]
end
0 commit comments