Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support both Float32 and Float64 #69

Merged
merged 1 commit into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions src/Biogeochemistry/nutrient_fields.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
function nutrients_init(arch, g)
fields = (Field(arch, g), Field(arch, g),
Field(arch, g), Field(arch, g),
Field(arch, g), Field(arch, g),
Field(arch, g), Field(arch, g),
Field(arch, g), Field(arch, g))
function nutrients_init(arch, g, FT = Float32)
fields = (Field(arch, g, FT), Field(arch, g, FT),
Field(arch, g, FT), Field(arch, g, FT),
Field(arch, g, FT), Field(arch, g, FT),
Field(arch, g, FT), Field(arch, g, FT),
Field(arch, g, FT), Field(arch, g, FT))

nut = NamedTuple{nut_names}(fields)
return nut
Expand All @@ -20,19 +20,22 @@ function default_nut_init()
end

"""
generate_nutrients(arch, grid, source)
generate_nutrients(arch, grid, source, FT)
Set up initial nutrient fields according to `grid`.

Keyword Arguments
Arguments
=================
- `arch`: `CPU()` or `GPU()`. The computer architecture used to time-step `model`.
- `grid`: The resolution and discrete geometry on which nutrient fields are solved.
- `source`: A `NamedTuple` containing 10 numbers each of which is the uniform initial condition of one tracer,
or a `Dict` containing the file paths pointing to the files of nutrient initial conditions.
- `source`: A `NamedTuple` containing 10 numbers each of which is the uniform initial
condition of one tracer, or a `Dict` containing the file paths pointing to
the files of nutrient initial conditions.
- `FT`: Floating point data type. Default: `Float32`.
"""
function generate_nutrients(arch, g, source::Union{Dict,NamedTuple})
function generate_nutrients(arch::Architecture, g::AbstractGrid,
source::Union{Dict,NamedTuple}, FT::DataType)
total_size = (g.Nx+g.Hx*2, g.Ny+g.Hy*2, g.Nz+g.Hz*2)
nut = nutrients_init(arch, g)
nut = nutrients_init(arch, g, FT)
pathkeys = collect(keys(source))

if typeof(source) <: NamedTuple
Expand All @@ -58,9 +61,9 @@ function generate_nutrients(arch, g, source::Union{Dict,NamedTuple})
if source.initial_condition[name] < 0.0
throw(ArgumentError("NUT_INIT: The initial condition should be none-negetive."))
end
lower = 1.0 - source.rand_noise[name]
upper = 1.0 + source.rand_noise[name]
nut[name].data .= fill(source.initial_condition[name],total_size) .* rand(lower:1e-4:upper, total_size) |> array_type(arch)
lower = FT(1.0 - source.rand_noise[name])
upper = FT(1.0 + source.rand_noise[name])
nut[name].data .= fill(FT(source.initial_condition[name]),total_size) .* rand(lower:1e-4:upper, total_size) |> array_type(arch)
end
end

Expand Down
15 changes: 8 additions & 7 deletions src/Diagnostics/diagnostics_struct.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mutable struct PlanktonDiagnostics
plankton::NamedTuple # for each species
tracer::NamedTuple # for tracers
iteration_interval::Int64 # time interval that the diagnostics is time averaged
iteration_interval::Int # time interval that the diagnostics is time averaged
end

"""
Expand All @@ -19,7 +19,7 @@ Keyword Arguments (Optional)
"""
function PlanktonDiagnostics(model; tracer=(),
plankton=(:num, :graz, :mort, :dvid),
iteration_interval::Int64 = 1)
iteration_interval::Int = 1)

@assert isa(tracer, Tuple)
@assert isa(plankton, Tuple)
Expand All @@ -30,15 +30,16 @@ function PlanktonDiagnostics(model; tracer=(),
nproc = length(plankton)
trs = []
procs = []
FT = model.FT

total_size = (model.grid.Nx+model.grid.Hx*2, model.grid.Ny+model.grid.Hy*2, model.grid.Nz+model.grid.Hz*2)

for i in 1:ntr
tr = zeros(total_size) |> array_type(model.arch)
tr = zeros(FT, total_size) |> array_type(model.arch)
push!(trs, tr)
end
tr_d1 = zeros(total_size) |> array_type(model.arch)
tr_d2 = zeros(total_size) |> array_type(model.arch)
tr_d1 = zeros(FT, total_size) |> array_type(model.arch)
tr_d2 = zeros(FT, total_size) |> array_type(model.arch)
tr_default = (PAR = tr_d1, T = tr_d2)

diag_tr = NamedTuple{tracer}(trs)
Expand All @@ -50,14 +51,14 @@ function PlanktonDiagnostics(model; tracer=(),
for j in 1:Nsp
procs_sp = []
for k in 1:nproc
proc = zeros(total_size) |> array_type(model.arch)
proc = zeros(FT, total_size) |> array_type(model.arch)
push!(procs_sp, proc)
end
diag_proc = NamedTuple{plankton}(procs_sp)

procs_sp_d = []
for l in 1:4
proc = zeros(total_size) |> array_type(model.arch)
proc = zeros(FT, total_size) |> array_type(model.arch)
push!(procs_sp_d, proc)
end
diag_proc_default = NamedTuple{(:num, :graz, :mort, :dvid)}(procs_sp_d)
Expand Down
13 changes: 7 additions & 6 deletions src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,20 @@ include("halo_regions.jl")
include("boundary_conditions.jl")
include("apply_bcs.jl")

struct Field
data::AbstractArray{Float64,3}
struct Field{FT}
data::AbstractArray{FT,3}
bc::BoundaryConditions
end

"""
Field(arch::Architecture, grid::AbstractGrid; bcs = default_bcs())
Field(arch::Architecture, grid::AbstractGrid, FT::DataType; bcs = default_bcs())
Construct a `Field` on `grid` with data and boundary conditions on architecture `arch`
with DataType `FT`.
"""
function Field(arch::Architecture, grid::AbstractGrid; bcs = default_bcs())
function Field(arch::Architecture, grid::AbstractGrid, FT::DataType; bcs = default_bcs())
total_size = (grid.Nx+grid.Hx*2, grid.Ny+grid.Hy*2, grid.Nz+grid.Hz*2)
data = zeros(total_size) |> array_type(arch)
return Field(data,bcs)
data = zeros(FT, total_size) |> array_type(arch)
return Field{FT}(data,bcs)
end

@inline interior(c, grid) = c[grid.Hx+1:grid.Hx+grid.Nx, grid.Hy+1:grid.Hy+grid.Ny, grid.Hz+1:grid.Hz+grid.Nz]
Expand Down
31 changes: 22 additions & 9 deletions src/Fields/boundary_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,48 @@ function default_bcs()
end

"""
set_bc!(model, tracer::Symbol, pos::Symbol, bc_value::Union{Number, AbstractArray})
Set the boundary condition of `tracer` on `pos` with `bc_value`.
set_bc!(model; tracer::Symbol, pos::Symbol, bc_value::Union{Number, AbstractArray})
Set the boundary condition of `tracer` on `pos` with `bc_value` of DataType `FT`.

Keyword Arguments
=================
- `tracer`: the tracer of which the boundary condition will be set.
- `pos`: the position of the bounday condition to be set, e.g., `:east`, `:top` etc.
- `bc_value`: the value that will be used to set the boundary condition.
"""
function set_bc!(model, tracer::Symbol, pos::Symbol, bc_value::Union{Number, AbstractArray})
function set_bc!(model; tracer::Symbol, pos::Symbol, bc_value::Union{Number, AbstractArray})
@assert tracer in nut_names

FT = model.FT
bc_value_d = bc_value
if isa(bc_value, AbstractArray)
bc_value_d = bc_value |> array_type(model.arch)
bc_value_d = FT.(bc_value) |> array_type(model.arch)
end
setproperty!(model.nutrients[tracer].bc, pos, bc_value_d)
return nothing
end

# get boundary condition at each grid point
@inline getbc(bc::Number, i, j, t) = bc
@inline getbc(bc::AbstractArray{Float64,2}, i, j, t) = bc[i,j]
@inline getbc(bc::AbstractArray{Float64,3}, i, j, t) = bc[i,j,t]
@inline function getbc(bc::Union{Number, AbstractArray}, i, j, t)
if typeof(bc) <: Number
return bc
elseif typeof(bc) <: AbstractArray{eltype(bc),2}
return bc[i,j]
elseif typeof(bc) <: AbstractArray{eltype(bc),3}
return bc[i,j,t]
end
end

# validate boundary conditions, check if the grid information is compatible with nutrient field
function validate_bc(bc, bc_size, nΔT)
if typeof(bc) <: AbstractArray{Float64,2}
if typeof(bc) <: AbstractArray{eltype(bc),2}
if size(bc) == bc_size
return nothing
else
throw(ArgumentError("BC west: grid mismatch, size(bc) must equal to $(bc_size) for a constant flux boundary condition."))
end
end
if typeof(bc) <: AbstractArray{Float64,3}
if typeof(bc) <: AbstractArray{eltype(bc),3}
if size(bc) == (bc_size..., nΔT)
return nothing
else
Expand Down
58 changes: 29 additions & 29 deletions src/Fields/halo_regions.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,37 @@
##### fill halo points based on topology
@inline fill_halo_west!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[1:H, :, :] = c[N+1:N+H, :, :]
@inline fill_halo_south!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[:, 1:H, :] = c[:, N+1:N+H, :]
@inline fill_halo_top!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[:, :, 1:H] = c[:, :, N+1:N+H]
@inline fill_halo_west!(c, H::Int, N::Int, ::Periodic) = @views @. c[1:H, :, :] = c[N+1:N+H, :, :]
@inline fill_halo_south!(c, H::Int, N::Int, ::Periodic) = @views @. c[:, 1:H, :] = c[:, N+1:N+H, :]
@inline fill_halo_top!(c, H::Int, N::Int, ::Periodic) = @views @. c[:, :, 1:H] = c[:, :, N+1:N+H]

@inline fill_halo_east!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[N+H+1:N+2H, :, :] = c[1+H:2H, :, :]
@inline fill_halo_north!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[:, N+H+1:N+2H, :] = c[:, 1+H:2H, :]
@inline fill_halo_bottom!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[:, :, N+H+1:N+2H] = c[:, :, 1+H:2H]
@inline fill_halo_east!(c, H::Int, N::Int, ::Periodic) = @views @. c[N+H+1:N+2H, :, :] = c[1+H:2H, :, :]
@inline fill_halo_north!(c, H::Int, N::Int, ::Periodic) = @views @. c[:, N+H+1:N+2H, :] = c[:, 1+H:2H, :]
@inline fill_halo_bottom!(c, H::Int, N::Int, ::Periodic) = @views @. c[:, :, N+H+1:N+2H] = c[:, :, 1+H:2H]

@inline fill_halo_west!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[1:H, :, :] = c[H+1:H+1, :, :]
@inline fill_halo_south!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, 1:H, :] = c[:, H+1:H+1, :]
@inline fill_halo_top!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, :, 1:H] = c[:, :, H+1:H+1]
@inline fill_halo_west!(c, H::Int, N::Int, ::Bounded) = @views @. c[1:H, :, :] = c[H+1:H+1, :, :]
@inline fill_halo_south!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, 1:H, :] = c[:, H+1:H+1, :]
@inline fill_halo_top!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, :, 1:H] = c[:, :, H+1:H+1]

@inline fill_halo_east!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[N+H+1:N+2H, :, :] = c[N+H:N+H, :, :]
@inline fill_halo_north!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, N+H+1:N+2H, :] = c[:, N+H:N+H, :]
@inline fill_halo_bottom!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, :, N+H+1:N+2H] = c[:, :, N+H:N+H]
@inline fill_halo_east!(c, H::Int, N::Int, ::Bounded) = @views @. c[N+H+1:N+2H, :, :] = c[N+H:N+H, :, :]
@inline fill_halo_north!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, N+H+1:N+2H, :] = c[:, N+H:N+H, :]
@inline fill_halo_bottom!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, :, N+H+1:N+2H] = c[:, :, N+H:N+H]

@inline fill_halo_east_vel!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[N+H+2:N+2H, :, :] = c[N+H+1:N+H+1, :, :]
@inline fill_halo_north_vel!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, N+H+2:N+2H, :] = c[:, N+H+1:N+H+1, :]
@inline fill_halo_bottom_vel!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, :, N+H+2:N+2H] = c[:, :, N+H+1:N+H+1]
@inline fill_halo_east_vel!(c, H::Int, N::Int, ::Bounded) = @views @. c[N+H+2:N+2H, :, :] = c[N+H+1:N+H+1, :, :]
@inline fill_halo_north_vel!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, N+H+2:N+2H, :] = c[:, N+H+1:N+H+1, :]
@inline fill_halo_bottom_vel!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, :, N+H+2:N+2H] = c[:, :, N+H+1:N+H+1]

@inline fill_halo_east_Gc!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[N+H+1:N+2H, :, :] = 0.0
@inline fill_halo_north_Gc!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, N+H+1:N+2H, :] = 0.0
@inline fill_halo_bottom_Gc!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, :, N+H+1:N+2H] = 0.0
@inline fill_halo_east_Gc!(c, H::Int, N::Int, ::Bounded) = @views @. c[N+H+1:N+2H, :, :] = 0.0
@inline fill_halo_north_Gc!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, N+H+1:N+2H, :] = 0.0
@inline fill_halo_bottom_Gc!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, :, N+H+1:N+2H] = 0.0

fill_halo_east_vel!(c, H::Int64, N::Int64, TX::Periodic) = fill_halo_east!(c, H, N, TX)
fill_halo_north_vel!(c, H::Int64, N::Int64, TY::Periodic) = fill_halo_north!(c, H, N, TY)
fill_halo_bottom_vel!(c, H::Int64, N::Int64, TZ::Periodic) = fill_halo_bottom!(c, H, N, TZ)
fill_halo_east_vel!(c, H::Int, N::Int, TX::Periodic) = fill_halo_east!(c, H, N, TX)
fill_halo_north_vel!(c, H::Int, N::Int, TY::Periodic) = fill_halo_north!(c, H, N, TY)
fill_halo_bottom_vel!(c, H::Int, N::Int, TZ::Periodic) = fill_halo_bottom!(c, H, N, TZ)

fill_halo_east_Gc!(c, H::Int64, N::Int64, TX::Periodic) = fill_halo_east!(c, H, N, TX)
fill_halo_north_Gc!(c, H::Int64, N::Int64, TY::Periodic) = fill_halo_north!(c, H, N, TY)
fill_halo_bottom_Gc!(c, H::Int64, N::Int64, TZ::Periodic) = fill_halo_bottom!(c, H, N, TZ)
fill_halo_east_Gc!(c, H::Int, N::Int, TX::Periodic) = fill_halo_east!(c, H, N, TX)
fill_halo_north_Gc!(c, H::Int, N::Int, TY::Periodic) = fill_halo_north!(c, H, N, TY)
fill_halo_bottom_Gc!(c, H::Int, N::Int, TZ::Periodic) = fill_halo_bottom!(c, H, N, TZ)

@inline function fill_halo_nut!(nuts::NamedTuple, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ}
@inline function fill_halo_nut!(nuts::NamedTuple, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ}
for nut in nuts
fill_halo_west!(nut.data, g.Hx, g.Nx, TX())
fill_halo_east!(nut.data, g.Hx, g.Nx, TX())
Expand All @@ -43,7 +43,7 @@ fill_halo_bottom_Gc!(c, H::Int64, N::Int64, TZ::Periodic) = fill_halo_bottom!(c,
return nothing
end

@inline function fill_halo_Gcs!(nuts::NamedTuple, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ}
@inline function fill_halo_Gcs!(nuts::NamedTuple, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ}
for nut in nuts
fill_halo_east_Gc!(nut.data, g.Hx, g.Nx, TX())
fill_halo_north_Gc!(nut.data, g.Hy, g.Ny, TY())
Expand All @@ -52,7 +52,7 @@ end
return nothing
end

@inline function fill_halo_u!(u, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ}
@inline function fill_halo_u!(u, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ}
fill_halo_east_vel!(u, g.Hx, g.Nx, TX())

fill_halo_west!(u, g.Hx, g.Nx, TX())
Expand All @@ -62,7 +62,7 @@ end
fill_halo_bottom!(u, g.Hz, g.Nz, TZ())
end

@inline function fill_halo_v!(v, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ}
@inline function fill_halo_v!(v, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ}
fill_halo_north_vel!(v, g.Hy, g.Ny, TY())

fill_halo_west!(v, g.Hx, g.Nx, TX())
Expand All @@ -72,7 +72,7 @@ end
fill_halo_bottom!(v, g.Hz, g.Nz, TZ())
end

@inline function fill_halo_w!(w, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ}
@inline function fill_halo_w!(w, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ}
fill_halo_bottom_vel!(w, g.Hz, g.Nz, TZ())

fill_halo_west!(w, g.Hx, g.Nx, TX())
Expand Down
6 changes: 3 additions & 3 deletions src/Grids/Grids.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ using Adapt
using PlanktonIndividuals.Architectures

"""
AbstractGrid{TX, TY, TZ}
Abstract type for grids with elements of type `Float64` and topology `{TX, TY, TZ}`.
AbstractGrid{FT, TX, TY, TZ}
Abstract type for grids with elements of type `FT` and topology `{TX, TY, TZ}`.
"""
abstract type AbstractGrid{TX, TY, YZ} end
abstract type AbstractGrid{FT, TX, TY, YZ} end

"""
AbstractTopology
Expand Down
Loading
Loading