Skip to content

Commit

Permalink
Add support for Metal.jl (#76)
Browse files Browse the repository at this point in the history
* add `Metal.jl` support

* minor bug fix

* bug fix for metal

`Metal` is not as fast as `CUDA` for now...
  • Loading branch information
zhenwu0728 authored Jul 24, 2024
1 parent 8202bb6 commit ff7af43
Show file tree
Hide file tree
Showing 29 changed files with 114 additions and 330 deletions.
297 changes: 33 additions & 264 deletions Manifest.toml

Large diffs are not rendered by default.

12 changes: 10 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ version = "0.7.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -17,9 +15,19 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

[extensions]
PI_CUDAExt = "CUDA"
PI_MetalExt = ["Metal", "GPUArrays"]

[compat]
Adapt = "^4"
CUDA = "^4, 5"
Metal = "^1.2"
GPUArrays = "10"
JLD2 = "^0.4"
KernelAbstractions = "^0.9"
Expand Down
12 changes: 12 additions & 0 deletions ext/PI_CUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module PI_CUDAExt

using CUDA
using CUDA.CUDAKernels
import PlanktonIndividuals.Architectures: GPU, device, array_type, rng_type, isfunctional

device(::GPU) = CUDABackend()
array_type(::GPU) = CuArray
rng_type(::GPU) = CURAND.default_rng()
isfunctional(::GPU) = CUDA.functional()

end
13 changes: 13 additions & 0 deletions ext/PI_MetalExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module PI_MetalExt

using Metal
using GPUArrays
using Metal.MetalKernels
import PlanktonIndividuals.Architectures: GPU, device, array_type, rng_type, isfunctional

device(::GPU) = MetalBackend()
array_type(::GPU) = MtlArray
rng_type(::GPU) = GPUArrays.default_rng(MtlArray)
isfunctional(::GPU) = Metal.functional()

end
15 changes: 5 additions & 10 deletions src/Architectures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ module Architectures

export CPU, GPU, Architecture
export array_type, rng_type
export device
export device, isfunctional

using CUDA
using GPUArrays
using KernelAbstractions
using CUDA.CUDAKernels

using Random

"""
Expand All @@ -28,14 +26,11 @@ Run PlanktonIndividuals on one CUDA GPU node.
"""
struct GPU <: Architecture end


##### CPU #####
device(::CPU) = KernelAbstractions.CPU()
device(::GPU) = CUDABackend()

array_type(::CPU) = Array
array_type(::GPU) = CuArray

rng_type(::CPU) = MersenneTwister()
rng_type(::GPU) = CURAND.default_rng()
isfunctional(::CPU) = true


end
1 change: 0 additions & 1 deletion src/Biogeochemistry/Biogeochemistry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ export nutrients_init, default_nut_init
export nut_update!

using KernelAbstractions
using CUDA

using PlanktonIndividuals.Architectures: device, array_type, Architecture
using PlanktonIndividuals.Grids
Expand Down
1 change: 0 additions & 1 deletion src/Diagnostics/Diagnostics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ export PlanktonDiagnostics
export diags_spcs!, diags_proc!
export tracer_avail_diags, plank_avail_diags

using CUDA
using KernelAbstractions

using PlanktonIndividuals.Architectures: device, Architecture, array_type
Expand Down
1 change: 0 additions & 1 deletion src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ export set_bc!, validate_bcs
export nut_names

using KernelAbstractions
using CUDA

using PlanktonIndividuals.Grids
using PlanktonIndividuals.Architectures: device, Architecture, array_type
Expand Down
1 change: 0 additions & 1 deletion src/Grids/Grids.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ export Periodic, Bounded, short_show
export replace_grid_storage
export ΔxC, ΔyC, ΔzC, ΔxF, ΔyF, ΔzF, Ax, Ay, Az, volume

using CUDA
using Adapt

using PlanktonIndividuals.Architectures
Expand Down
1 change: 0 additions & 1 deletion src/Model/Model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module Model
export PlanktonModel
export TimeStep!

using CUDA
using StructArrays
using LinearAlgebra: dot

Expand Down
4 changes: 1 addition & 3 deletions src/Model/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ function PlanktonModel(arch::Architecture, grid::AbstractGrid;

@assert maximum(N_individual) max_individuals

if arch == GPU() && !has_cuda()
throw(ArgumentError("Cannot create a GPU model. No CUDA-enabled GPU was detected!"))
end
@assert isfunctional(arch) == true

grid_d = replace_grid_storage(arch, grid)

Expand Down
7 changes: 4 additions & 3 deletions src/Model/time_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@ function TimeStep!(model::PlanktonModel, ΔT, diags::PlanktonDiagnostics)
end

##### calculate PAR
calc_par!(model.timestepper.par, model.arch, model.timestepper.Chl, model.timestepper.PARF,
model.grid, model.bgc_params["kc"], model.bgc_params["kw"])

for ki in 1:model.grid.Nz
calc_par!(model.timestepper.par, model.arch, model.timestepper.Chl, model.timestepper.PARF,
model.grid, model.bgc_params["kc"], model.bgc_params["kw"], ki)
end
##### diagnostics for nutrients
@inbounds diags.tracer.PAR .+= model.timestepper.par
@inbounds diags.tracer.T .+= model.timestepper.temp
Expand Down
5 changes: 3 additions & 2 deletions src/Model/timestepper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ function timestepper(arch::Architecture, FT::DataType, g::AbstractGrid, maxN)
velos_d = replace_storage(array_type(arch), velos)

nuts = StructArray(NH4 = zeros(FT, maxN), NO3 = zeros(FT, maxN), PO4 = zeros(FT, maxN),
DOC = zeros(FT, maxN), FeT = zeros(FT, maxN), idc = zeros(Int,maxN),
par = zeros(FT, maxN), T = zeros(FT, maxN), pop = zeros(FT, maxN))
DOC = zeros(FT, maxN), FeT = zeros(FT, maxN), par = zeros(FT, maxN),
T = zeros(FT, maxN), pop = zeros(FT, maxN), idc = zeros(FT, maxN),
idc_int = zeros(Int, maxN))
nuts_d = replace_storage(array_type(arch), nuts)

ts = timestepper(Gcs, nut_temp, vel₀, vel½, vel₁, PARF, temp, plk, par, Chl, pop, rnd_d, velos_d, nuts_d)
Expand Down
1 change: 0 additions & 1 deletion src/Parameters/Parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ export bgc_params_default, phyt_params_default
export default_PARF, default_temperature
export update_bgc_params, update_phyt_params

using CUDA
using PlanktonIndividuals.Grids

using PlanktonIndividuals: AbstractMode, CarbonMode, QuotaMode, MacroMolecularMode, IronEnergyMode
Expand Down
2 changes: 1 addition & 1 deletion src/Parameters/param_default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ function default_temperature(grid, ΔT, iterations)
end
# vertical temperature gradient
for j in grid.Nz-1:-1:1
CUDA.@allowscalar temp_day[:,:,j,:] .= temp_day[:,:,j+1,:] .+ 4.0e-4 * grid.zC[j]
temp_day[:,:,j,:] .= temp_day[:,:,j+1,:] .+ 4.0e-4 * grid.zC[j]
end
temp_domain = repeat(temp_day, outer = (1,1,1,total_days))
return temp_domain
Expand Down
1 change: 0 additions & 1 deletion src/Plankton/Advection/Advection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ export plankton_diffusion!
export get_xf_index, get_yf_index, get_zf_index

using KernelAbstractions
using CUDA
using Random

using PlanktonIndividuals.Architectures: device, Architecture, rng_type
Expand Down
1 change: 0 additions & 1 deletion src/Plankton/CarbonMode/CarbonMode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ export plankton_update!
export construct_plankton, generate_plankton!

using KernelAbstractions
using CUDA
using StructArrays
using Random
using LinearAlgebra: dot
Expand Down
5 changes: 3 additions & 2 deletions src/Plankton/CarbonMode/division_death.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ function divide_to_half!(plank, arch)
return nothing
end
function divide!(plank, nuts, deactive_ind, arch::Architecture)
accumulate!(+, nuts.idc, Int.(plank.dvid))
get_tind!(plank.idx, plank.dvid, nuts.idc, deactive_ind, arch)
accumulate!(+, nuts.idc, plank.dvid)
nuts.idc_int .= unsafe_trunc.(Int, nuts.idc)
get_tind!(plank.idx, plank.dvid, nuts.idc_int, deactive_ind, arch)
copy_daughter_individuals!(plank, plank.dvid, plank.idx, arch)
divide_to_half!(plank, arch)
return nothing
Expand Down
1 change: 0 additions & 1 deletion src/Plankton/IronEnergyMode/IronEnergyMode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ export plankton_update!
export construct_plankton, generate_plankton!

using KernelAbstractions
using CUDA
using StructArrays
using Random
using LinearAlgebra: dot
Expand Down
5 changes: 3 additions & 2 deletions src/Plankton/IronEnergyMode/division_death.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ function divide_to_half!(plank, arch)
return nothing
end
function divide!(plank, nuts, deactive_ind, arch::Architecture)
accumulate!(+, nuts.idc, Int.(plank.dvid))
get_tind!(plank.idx, plank.dvid, nuts.idc, deactive_ind, arch)
accumulate!(+, nuts.idc, plank.dvid)
nuts.idc_int .= unsafe_trunc.(Int, nuts.idc)
get_tind!(plank.idx, plank.dvid, nuts.idc_int, deactive_ind, arch)
copy_daughter_individuals!(plank, plank.dvid, plank.idx, arch)
divide_to_half!(plank, arch)
return nothing
Expand Down
1 change: 0 additions & 1 deletion src/Plankton/MacroMolecularMode/MacroMolecularMode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ export plankton_update!
export construct_plankton, generate_plankton!

using KernelAbstractions
using CUDA
using StructArrays
using Random
using LinearAlgebra: dot
Expand Down
5 changes: 3 additions & 2 deletions src/Plankton/MacroMolecularMode/division_death.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ function divide_to_half!(plank, arch)
return nothing
end
function divide!(plank, nuts, deactive_ind, arch::Architecture)
accumulate!(+, nuts.idc, Int.(plank.dvid))
get_tind!(plank.idx, plank.dvid, nuts.idc, deactive_ind, arch)
accumulate!(+, nuts.idc, plank.dvid)
nuts.idc_int .= unsafe_trunc.(Int, nuts.idc)
get_tind!(plank.idx, plank.dvid, nuts.idc_int, deactive_ind, arch)
copy_daughter_individuals!(plank, plank.dvid, plank.idx, arch)
divide_to_half!(plank, arch)
return nothing
Expand Down
1 change: 0 additions & 1 deletion src/Plankton/Plankton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ export plankton_update!
export generate_individuals, individuals
export find_inds!, find_NPT!, acc_counts!, calc_par!

using CUDA
using StructArrays
using Random
using KernelAbstractions
Expand Down
1 change: 0 additions & 1 deletion src/Plankton/QuotaMode/QuotaMode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ export plankton_update!
export construct_plankton, generate_plankton!

using KernelAbstractions
using CUDA
using StructArrays
using Random
using LinearAlgebra: dot
Expand Down
5 changes: 3 additions & 2 deletions src/Plankton/QuotaMode/division_death.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ function divide_to_half!(plank, arch)
return nothing
end
function divide!(plank, nuts, deactive_ind, arch::Architecture)
accumulate!(+, nuts.idc, Int.(plank.dvid))
get_tind!(plank.idx, plank.dvid, nuts.idc, deactive_ind, arch)
accumulate!(+, nuts.idc, plank.dvid)
nuts.idc_int .= unsafe_trunc.(Int, nuts.idc)
get_tind!(plank.idx, plank.dvid, nuts.idc_int, deactive_ind, arch)
copy_daughter_individuals!(plank, plank.dvid, plank.idx, arch)
divide_to_half!(plank, arch)
return nothing
Expand Down
20 changes: 9 additions & 11 deletions src/Plankton/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,20 +52,18 @@ function acc_counts!(ctsChl, ctspop, Chl, ac, x, y, z, arch)
end

##### calculate PAR field based on Chla field and depth
@kernel function calc_par_kernel!(par, Chl, PARF, g::AbstractGrid, kc, kw)
@kernel function calc_par_kernel!(par, Chl, PARF, g::AbstractGrid, kc, kw, ki)
i, j = @index(Global, NTuple)
for k in 1:g.Nz
ii = i + g.Hx
jj = j + g.Hy
kk = k + g.Hz
atten = (Chl[ii,jj,kk]/volume(ii, jj, kk, g) * kc + kw) * ΔzF(ii, jj, kk, g)
par[ii,jj,kk] = PARF[i,j] * (1.0f0 - exp(-atten)) / atten
PARF[i,j] = PARF[i,j] * exp(-atten)
end
ii = i + g.Hx
jj = j + g.Hy
kk = ki + g.Hz
atten = (Chl[ii,jj,kk]/volume(ii, jj, kk, g) * kc + kw) * ΔzF(ii, jj, kk, g)
par[ii,jj,kk] = PARF[i,j] * (1.0f0 - exp(-atten)) / atten
PARF[i,j] = PARF[i,j] * exp(-atten)
end
function calc_par!(par, arch::Architecture, Chl, PARF, g::AbstractGrid, kc, kw)
function calc_par!(par, arch::Architecture, Chl, PARF, g::AbstractGrid, kc, kw, ki)
kernel! = calc_par_kernel!(device(arch), (16,16), (g.Nx, g.Ny))
kernel!(par, Chl, PARF, g, kc, kw)
kernel!(par, Chl, PARF, g, kc, kw, ki)
return nothing
end

Expand Down
11 changes: 0 additions & 11 deletions src/PlanktonIndividuals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ export
seconds, minutes, hours, meters, kilometers,
KiB, MiB, GiB, TiB

using CUDA
using Pkg.Artifacts

import Base: show
Expand Down Expand Up @@ -107,14 +106,4 @@ using .Output
using .Simulation
using .Units

function __init__()
if CUDA.has_cuda()
@debug "CUDA-enabled GPU detected:"
for (gpu, dev) in enumerate(CUDA.devices())
@debug "$dev: $(CUDA.name(dev))"
end
CUDA.allowscalar(false)
end
end

end # module
1 change: 0 additions & 1 deletion src/Simulation/Simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module Simulation
export PlanktonSimulation
export update!, vel_copy!, set_vels_fields!, set_PARF_fields!, set_temp_fields!

using CUDA
using StructArrays
using LinearAlgebra: dot

Expand Down
13 changes: 11 additions & 2 deletions src/Simulation/simulations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ Keyword Arguments (Optional)
- `output_writer` : Output writer of the simulation generated by `PlanktonOutputWriter`.
"""
function PlanktonSimulation(model::PlanktonModel; ΔT::AbstractFloat, iterations::Int,
PARF = default_PARF(model.grid, ΔT, iterations),
temp = default_temperature(model.grid, ΔT, iterations),
PARF = nothing,
temp = nothing,
diags = nothing,
vels = (;),
ΔT_vel::AbstractFloat = ΔT,
Expand All @@ -67,6 +67,15 @@ function PlanktonSimulation(model::PlanktonModel; ΔT::AbstractFloat, iterations
vels_ft = (;)
end

if isa(PARF, Nothing)
PARF = default_PARF(model.grid, ΔT, iterations)
end

if isa(temp, Nothing)
grid = replace_grid_storage(CPU(), model.grid)
temp = default_temperature(grid, ΔT, iterations)
end

input = PlanktonInput(FT.(temp), FT.(PARF), vels_ft, FT(ΔT_vel), FT(ΔT_PAR), FT(ΔT_temp))

validate_bcs(model.nutrients, model.grid, iterations)
Expand Down

0 comments on commit ff7af43

Please sign in to comment.