From a9e56fcfa0b3263022ea5c58a3a067793a6db9bf Mon Sep 17 00:00:00 2001 From: Zhen Wu Date: Wed, 8 May 2024 22:50:18 -0400 Subject: [PATCH] support both `Float32` and `Float64` --- src/Biogeochemistry/nutrient_fields.jl | 33 ++-- src/Diagnostics/diagnostics_struct.jl | 15 +- src/Fields/Fields.jl | 13 +- src/Fields/boundary_conditions.jl | 31 +++- src/Fields/halo_regions.jl | 58 +++--- src/Grids/Grids.jl | 6 +- src/Grids/lat_lon_grid.jl | 174 +++++++++--------- src/Grids/rectilinear_grid.jl | 65 +++---- src/Grids/utils.jl | 23 ++- src/Model/models.jl | 21 ++- src/Model/timestepper.jl | 37 ++-- .../Advection/plankton_advection_kernels.jl | 2 +- .../CarbonMode/plankton_generation.jl | 20 +- .../MacroMolecularMode/plankton_generation.jl | 26 +-- src/Plankton/Plankton.jl | 16 +- src/Plankton/QuotaMode/plankton_generation.jl | 24 +-- src/Simulation/simulations.jl | 20 +- src/Simulation/utils.jl | 10 +- test/field_test.jl | 10 +- test/output_test.jl | 4 +- 20 files changed, 320 insertions(+), 288 deletions(-) diff --git a/src/Biogeochemistry/nutrient_fields.jl b/src/Biogeochemistry/nutrient_fields.jl index 6bd3c709..f4ad21c8 100644 --- a/src/Biogeochemistry/nutrient_fields.jl +++ b/src/Biogeochemistry/nutrient_fields.jl @@ -1,9 +1,9 @@ -function nutrients_init(arch, g) - fields = (Field(arch, g), Field(arch, g), - Field(arch, g), Field(arch, g), - Field(arch, g), Field(arch, g), - Field(arch, g), Field(arch, g), - Field(arch, g), Field(arch, g)) +function nutrients_init(arch, g, FT = Float32) + fields = (Field(arch, g, FT), Field(arch, g, FT), + Field(arch, g, FT), Field(arch, g, FT), + Field(arch, g, FT), Field(arch, g, FT), + Field(arch, g, FT), Field(arch, g, FT), + Field(arch, g, FT), Field(arch, g, FT)) nut = NamedTuple{nut_names}(fields) return nut @@ -20,19 +20,22 @@ function default_nut_init() end """ - generate_nutrients(arch, grid, source) + generate_nutrients(arch, grid, source, FT) Set up initial nutrient fields according to `grid`. -Keyword Arguments +Arguments ================= - `arch`: `CPU()` or `GPU()`. The computer architecture used to time-step `model`. - `grid`: The resolution and discrete geometry on which nutrient fields are solved. -- `source`: A `NamedTuple` containing 10 numbers each of which is the uniform initial condition of one tracer, - or a `Dict` containing the file paths pointing to the files of nutrient initial conditions. +- `source`: A `NamedTuple` containing 10 numbers each of which is the uniform initial + condition of one tracer, or a `Dict` containing the file paths pointing to + the files of nutrient initial conditions. +- `FT`: Floating point data type. Default: `Float32`. """ -function generate_nutrients(arch, g, source::Union{Dict,NamedTuple}) +function generate_nutrients(arch::Architecture, g::AbstractGrid, + source::Union{Dict,NamedTuple}, FT::DataType) total_size = (g.Nx+g.Hx*2, g.Ny+g.Hy*2, g.Nz+g.Hz*2) - nut = nutrients_init(arch, g) + nut = nutrients_init(arch, g, FT) pathkeys = collect(keys(source)) if typeof(source) <: NamedTuple @@ -58,9 +61,9 @@ function generate_nutrients(arch, g, source::Union{Dict,NamedTuple}) if source.initial_condition[name] < 0.0 throw(ArgumentError("NUT_INIT: The initial condition should be none-negetive.")) end - lower = 1.0 - source.rand_noise[name] - upper = 1.0 + source.rand_noise[name] - nut[name].data .= fill(source.initial_condition[name],total_size) .* rand(lower:1e-4:upper, total_size) |> array_type(arch) + lower = FT(1.0 - source.rand_noise[name]) + upper = FT(1.0 + source.rand_noise[name]) + nut[name].data .= fill(FT(source.initial_condition[name]),total_size) .* rand(lower:1e-4:upper, total_size) |> array_type(arch) end end diff --git a/src/Diagnostics/diagnostics_struct.jl b/src/Diagnostics/diagnostics_struct.jl index ca212ed1..fcc938d5 100644 --- a/src/Diagnostics/diagnostics_struct.jl +++ b/src/Diagnostics/diagnostics_struct.jl @@ -1,7 +1,7 @@ mutable struct PlanktonDiagnostics plankton::NamedTuple # for each species tracer::NamedTuple # for tracers - iteration_interval::Int64 # time interval that the diagnostics is time averaged + iteration_interval::Int # time interval that the diagnostics is time averaged end """ @@ -19,7 +19,7 @@ Keyword Arguments (Optional) """ function PlanktonDiagnostics(model; tracer=(), plankton=(:num, :graz, :mort, :dvid), - iteration_interval::Int64 = 1) + iteration_interval::Int = 1) @assert isa(tracer, Tuple) @assert isa(plankton, Tuple) @@ -30,15 +30,16 @@ function PlanktonDiagnostics(model; tracer=(), nproc = length(plankton) trs = [] procs = [] + FT = model.FT total_size = (model.grid.Nx+model.grid.Hx*2, model.grid.Ny+model.grid.Hy*2, model.grid.Nz+model.grid.Hz*2) for i in 1:ntr - tr = zeros(total_size) |> array_type(model.arch) + tr = zeros(FT, total_size) |> array_type(model.arch) push!(trs, tr) end - tr_d1 = zeros(total_size) |> array_type(model.arch) - tr_d2 = zeros(total_size) |> array_type(model.arch) + tr_d1 = zeros(FT, total_size) |> array_type(model.arch) + tr_d2 = zeros(FT, total_size) |> array_type(model.arch) tr_default = (PAR = tr_d1, T = tr_d2) diag_tr = NamedTuple{tracer}(trs) @@ -50,14 +51,14 @@ function PlanktonDiagnostics(model; tracer=(), for j in 1:Nsp procs_sp = [] for k in 1:nproc - proc = zeros(total_size) |> array_type(model.arch) + proc = zeros(FT, total_size) |> array_type(model.arch) push!(procs_sp, proc) end diag_proc = NamedTuple{plankton}(procs_sp) procs_sp_d = [] for l in 1:4 - proc = zeros(total_size) |> array_type(model.arch) + proc = zeros(FT, total_size) |> array_type(model.arch) push!(procs_sp_d, proc) end diag_proc_default = NamedTuple{(:num, :graz, :mort, :dvid)}(procs_sp_d) diff --git a/src/Fields/Fields.jl b/src/Fields/Fields.jl index 8db22cbb..6ceb1221 100644 --- a/src/Fields/Fields.jl +++ b/src/Fields/Fields.jl @@ -19,19 +19,20 @@ include("halo_regions.jl") include("boundary_conditions.jl") include("apply_bcs.jl") -struct Field - data::AbstractArray{Float64,3} +struct Field{FT} + data::AbstractArray{FT,3} bc::BoundaryConditions end """ - Field(arch::Architecture, grid::AbstractGrid; bcs = default_bcs()) + Field(arch::Architecture, grid::AbstractGrid, FT::DataType; bcs = default_bcs()) Construct a `Field` on `grid` with data and boundary conditions on architecture `arch` +with DataType `FT`. """ -function Field(arch::Architecture, grid::AbstractGrid; bcs = default_bcs()) +function Field(arch::Architecture, grid::AbstractGrid, FT::DataType; bcs = default_bcs()) total_size = (grid.Nx+grid.Hx*2, grid.Ny+grid.Hy*2, grid.Nz+grid.Hz*2) - data = zeros(total_size) |> array_type(arch) - return Field(data,bcs) + data = zeros(FT, total_size) |> array_type(arch) + return Field{FT}(data,bcs) end @inline interior(c, grid) = c[grid.Hx+1:grid.Hx+grid.Nx, grid.Hy+1:grid.Hy+grid.Ny, grid.Hz+1:grid.Hz+grid.Nz] diff --git a/src/Fields/boundary_conditions.jl b/src/Fields/boundary_conditions.jl index 59de3df1..2c8ad663 100644 --- a/src/Fields/boundary_conditions.jl +++ b/src/Fields/boundary_conditions.jl @@ -19,35 +19,48 @@ function default_bcs() end """ - set_bc!(model, tracer::Symbol, pos::Symbol, bc_value::Union{Number, AbstractArray}) -Set the boundary condition of `tracer` on `pos` with `bc_value`. + set_bc!(model; tracer::Symbol, pos::Symbol, bc_value::Union{Number, AbstractArray}) +Set the boundary condition of `tracer` on `pos` with `bc_value` of DataType `FT`. + +Keyword Arguments +================= +- `tracer`: the tracer of which the boundary condition will be set. +- `pos`: the position of the bounday condition to be set, e.g., `:east`, `:top` etc. +- `bc_value`: the value that will be used to set the boundary condition. """ -function set_bc!(model, tracer::Symbol, pos::Symbol, bc_value::Union{Number, AbstractArray}) +function set_bc!(model; tracer::Symbol, pos::Symbol, bc_value::Union{Number, AbstractArray}) @assert tracer in nut_names + FT = model.FT bc_value_d = bc_value if isa(bc_value, AbstractArray) - bc_value_d = bc_value |> array_type(model.arch) + bc_value_d = FT.(bc_value) |> array_type(model.arch) end setproperty!(model.nutrients[tracer].bc, pos, bc_value_d) return nothing end # get boundary condition at each grid point -@inline getbc(bc::Number, i, j, t) = bc -@inline getbc(bc::AbstractArray{Float64,2}, i, j, t) = bc[i,j] -@inline getbc(bc::AbstractArray{Float64,3}, i, j, t) = bc[i,j,t] +@inline function getbc(bc::Union{Number, AbstractArray}, i, j, t) + if typeof(bc) <: Number + return bc + elseif typeof(bc) <: AbstractArray{eltype(bc),2} + return bc[i,j] + elseif typeof(bc) <: AbstractArray{eltype(bc),3} + return bc[i,j,t] + end +end # validate boundary conditions, check if the grid information is compatible with nutrient field function validate_bc(bc, bc_size, nΔT) - if typeof(bc) <: AbstractArray{Float64,2} + if typeof(bc) <: AbstractArray{eltype(bc),2} if size(bc) == bc_size return nothing else throw(ArgumentError("BC west: grid mismatch, size(bc) must equal to $(bc_size) for a constant flux boundary condition.")) end end - if typeof(bc) <: AbstractArray{Float64,3} + if typeof(bc) <: AbstractArray{eltype(bc),3} if size(bc) == (bc_size..., nΔT) return nothing else diff --git a/src/Fields/halo_regions.jl b/src/Fields/halo_regions.jl index 696f8078..c86b98a8 100644 --- a/src/Fields/halo_regions.jl +++ b/src/Fields/halo_regions.jl @@ -1,37 +1,37 @@ ##### fill halo points based on topology -@inline fill_halo_west!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[1:H, :, :] = c[N+1:N+H, :, :] -@inline fill_halo_south!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[:, 1:H, :] = c[:, N+1:N+H, :] -@inline fill_halo_top!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[:, :, 1:H] = c[:, :, N+1:N+H] +@inline fill_halo_west!(c, H::Int, N::Int, ::Periodic) = @views @. c[1:H, :, :] = c[N+1:N+H, :, :] +@inline fill_halo_south!(c, H::Int, N::Int, ::Periodic) = @views @. c[:, 1:H, :] = c[:, N+1:N+H, :] +@inline fill_halo_top!(c, H::Int, N::Int, ::Periodic) = @views @. c[:, :, 1:H] = c[:, :, N+1:N+H] -@inline fill_halo_east!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[N+H+1:N+2H, :, :] = c[1+H:2H, :, :] -@inline fill_halo_north!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[:, N+H+1:N+2H, :] = c[:, 1+H:2H, :] -@inline fill_halo_bottom!(c, H::Int64, N::Int64, ::Periodic) = @views @. c[:, :, N+H+1:N+2H] = c[:, :, 1+H:2H] +@inline fill_halo_east!(c, H::Int, N::Int, ::Periodic) = @views @. c[N+H+1:N+2H, :, :] = c[1+H:2H, :, :] +@inline fill_halo_north!(c, H::Int, N::Int, ::Periodic) = @views @. c[:, N+H+1:N+2H, :] = c[:, 1+H:2H, :] +@inline fill_halo_bottom!(c, H::Int, N::Int, ::Periodic) = @views @. c[:, :, N+H+1:N+2H] = c[:, :, 1+H:2H] -@inline fill_halo_west!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[1:H, :, :] = c[H+1:H+1, :, :] -@inline fill_halo_south!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, 1:H, :] = c[:, H+1:H+1, :] -@inline fill_halo_top!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, :, 1:H] = c[:, :, H+1:H+1] +@inline fill_halo_west!(c, H::Int, N::Int, ::Bounded) = @views @. c[1:H, :, :] = c[H+1:H+1, :, :] +@inline fill_halo_south!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, 1:H, :] = c[:, H+1:H+1, :] +@inline fill_halo_top!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, :, 1:H] = c[:, :, H+1:H+1] -@inline fill_halo_east!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[N+H+1:N+2H, :, :] = c[N+H:N+H, :, :] -@inline fill_halo_north!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, N+H+1:N+2H, :] = c[:, N+H:N+H, :] -@inline fill_halo_bottom!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, :, N+H+1:N+2H] = c[:, :, N+H:N+H] +@inline fill_halo_east!(c, H::Int, N::Int, ::Bounded) = @views @. c[N+H+1:N+2H, :, :] = c[N+H:N+H, :, :] +@inline fill_halo_north!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, N+H+1:N+2H, :] = c[:, N+H:N+H, :] +@inline fill_halo_bottom!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, :, N+H+1:N+2H] = c[:, :, N+H:N+H] -@inline fill_halo_east_vel!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[N+H+2:N+2H, :, :] = c[N+H+1:N+H+1, :, :] -@inline fill_halo_north_vel!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, N+H+2:N+2H, :] = c[:, N+H+1:N+H+1, :] -@inline fill_halo_bottom_vel!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, :, N+H+2:N+2H] = c[:, :, N+H+1:N+H+1] +@inline fill_halo_east_vel!(c, H::Int, N::Int, ::Bounded) = @views @. c[N+H+2:N+2H, :, :] = c[N+H+1:N+H+1, :, :] +@inline fill_halo_north_vel!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, N+H+2:N+2H, :] = c[:, N+H+1:N+H+1, :] +@inline fill_halo_bottom_vel!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, :, N+H+2:N+2H] = c[:, :, N+H+1:N+H+1] -@inline fill_halo_east_Gc!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[N+H+1:N+2H, :, :] = 0.0 -@inline fill_halo_north_Gc!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, N+H+1:N+2H, :] = 0.0 -@inline fill_halo_bottom_Gc!(c, H::Int64, N::Int64, ::Bounded) = @views @. c[:, :, N+H+1:N+2H] = 0.0 +@inline fill_halo_east_Gc!(c, H::Int, N::Int, ::Bounded) = @views @. c[N+H+1:N+2H, :, :] = 0.0 +@inline fill_halo_north_Gc!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, N+H+1:N+2H, :] = 0.0 +@inline fill_halo_bottom_Gc!(c, H::Int, N::Int, ::Bounded) = @views @. c[:, :, N+H+1:N+2H] = 0.0 -fill_halo_east_vel!(c, H::Int64, N::Int64, TX::Periodic) = fill_halo_east!(c, H, N, TX) -fill_halo_north_vel!(c, H::Int64, N::Int64, TY::Periodic) = fill_halo_north!(c, H, N, TY) -fill_halo_bottom_vel!(c, H::Int64, N::Int64, TZ::Periodic) = fill_halo_bottom!(c, H, N, TZ) +fill_halo_east_vel!(c, H::Int, N::Int, TX::Periodic) = fill_halo_east!(c, H, N, TX) +fill_halo_north_vel!(c, H::Int, N::Int, TY::Periodic) = fill_halo_north!(c, H, N, TY) +fill_halo_bottom_vel!(c, H::Int, N::Int, TZ::Periodic) = fill_halo_bottom!(c, H, N, TZ) -fill_halo_east_Gc!(c, H::Int64, N::Int64, TX::Periodic) = fill_halo_east!(c, H, N, TX) -fill_halo_north_Gc!(c, H::Int64, N::Int64, TY::Periodic) = fill_halo_north!(c, H, N, TY) -fill_halo_bottom_Gc!(c, H::Int64, N::Int64, TZ::Periodic) = fill_halo_bottom!(c, H, N, TZ) +fill_halo_east_Gc!(c, H::Int, N::Int, TX::Periodic) = fill_halo_east!(c, H, N, TX) +fill_halo_north_Gc!(c, H::Int, N::Int, TY::Periodic) = fill_halo_north!(c, H, N, TY) +fill_halo_bottom_Gc!(c, H::Int, N::Int, TZ::Periodic) = fill_halo_bottom!(c, H, N, TZ) -@inline function fill_halo_nut!(nuts::NamedTuple, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ} +@inline function fill_halo_nut!(nuts::NamedTuple, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} for nut in nuts fill_halo_west!(nut.data, g.Hx, g.Nx, TX()) fill_halo_east!(nut.data, g.Hx, g.Nx, TX()) @@ -43,7 +43,7 @@ fill_halo_bottom_Gc!(c, H::Int64, N::Int64, TZ::Periodic) = fill_halo_bottom!(c, return nothing end -@inline function fill_halo_Gcs!(nuts::NamedTuple, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ} +@inline function fill_halo_Gcs!(nuts::NamedTuple, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} for nut in nuts fill_halo_east_Gc!(nut.data, g.Hx, g.Nx, TX()) fill_halo_north_Gc!(nut.data, g.Hy, g.Ny, TY()) @@ -52,7 +52,7 @@ end return nothing end -@inline function fill_halo_u!(u, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ} +@inline function fill_halo_u!(u, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} fill_halo_east_vel!(u, g.Hx, g.Nx, TX()) fill_halo_west!(u, g.Hx, g.Nx, TX()) @@ -62,7 +62,7 @@ end fill_halo_bottom!(u, g.Hz, g.Nz, TZ()) end -@inline function fill_halo_v!(v, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ} +@inline function fill_halo_v!(v, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} fill_halo_north_vel!(v, g.Hy, g.Ny, TY()) fill_halo_west!(v, g.Hx, g.Nx, TX()) @@ -72,7 +72,7 @@ end fill_halo_bottom!(v, g.Hz, g.Nz, TZ()) end -@inline function fill_halo_w!(w, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ} +@inline function fill_halo_w!(w, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} fill_halo_bottom_vel!(w, g.Hz, g.Nz, TZ()) fill_halo_west!(w, g.Hx, g.Nx, TX()) diff --git a/src/Grids/Grids.jl b/src/Grids/Grids.jl index 685a40fc..f313a8c3 100644 --- a/src/Grids/Grids.jl +++ b/src/Grids/Grids.jl @@ -14,10 +14,10 @@ using Adapt using PlanktonIndividuals.Architectures """ - AbstractGrid{TX, TY, TZ} -Abstract type for grids with elements of type `Float64` and topology `{TX, TY, TZ}`. + AbstractGrid{FT, TX, TY, TZ} +Abstract type for grids with elements of type `FT` and topology `{TX, TY, TZ}`. """ -abstract type AbstractGrid{TX, TY, YZ} end +abstract type AbstractGrid{FT, TX, TY, YZ} end """ AbstractTopology diff --git a/src/Grids/lat_lon_grid.jl b/src/Grids/lat_lon_grid.jl index e0863a8f..ca5600f8 100644 --- a/src/Grids/lat_lon_grid.jl +++ b/src/Grids/lat_lon_grid.jl @@ -1,30 +1,30 @@ -struct LatLonGrid{TX, TY, TZ, R, A1, A2, A3} <: AbstractGrid{TX, TY, TZ} +struct LatLonGrid{FT, TX, TY, TZ} <: AbstractGrid{FT, TX, TY, TZ} # corrdinates at cell centers, unit: degree - xC::R - yC::R + xC::Vector{FT} + yC::Vector{FT} # corrdinates at cell centers, unit: meter - zC::A1 + zC::Vector{FT} # corrdinates at cell faces, unit: degree - xF::R - yF::R + xF::Vector{FT} + yF::Vector{FT} # corrdinates at cell faces, unit: meter - zF::A1 + zF::Vector{FT} # grid spacing, unit: degree - Δx::Float64 - Δy::Float64 + Δx::FT + Δy::FT # grid spacing from center to center, unit: meter - dxC::A2 - dyC::A2 - dzC::A1 + dxC::AbstractArray{FT,2} + dyC::AbstractArray{FT,2} + dzC::Vector{FT} # grid spacing from face to face, unit: meter - dxF::A2 - dyF::A2 - dzF::A1 + dxF::AbstractArray{FT,2} + dyF::AbstractArray{FT,2} + dzF::Vector{FT} # areas and volume, unit: m² or m³ - Ax::A3 - Ay::A3 - Az::A2 - Vol::A3 + Ax::AbstractArray{FT,3} + Ay::AbstractArray{FT,3} + Az::AbstractArray{FT,2} + Vol::AbstractArray{FT,3} # number of grid points Nx::Int Ny::Int @@ -34,11 +34,12 @@ struct LatLonGrid{TX, TY, TZ, R, A1, A2, A3} <: AbstractGrid{TX, TY, TZ} Hy::Int Hz::Int # landmask to indicate where is the land - landmask::A3 + landmask::AbstractArray{FT,3} end """ LatLonGrid(;size, lat, lon, z, + FT = Float32, radius = 6370.0e3, landmask = nothing, halo = (2, 2, 2)) @@ -58,6 +59,7 @@ Keyword Arguments (Required) Keyword Arguments (Optional) ============================ +- `FT`: Floating point data type. Default: `Float32`. - `radius` : Specify the radius of the Earth used in the model, 6370.0e3 meters by default. - `landmask` : a 3-dimentional array to indicate where the land is. - `halo` : A tuple of integers that specifies the size of the halo region of cells @@ -66,6 +68,7 @@ Keyword Arguments (Optional) At least 2 halo points are needed for DST3FL advection scheme. """ function LatLonGrid(;size, lat, lon, z, + FT = Float32, radius = 6370.0e3, landmask = nothing, halo = (2, 2, 2)) @@ -76,7 +79,7 @@ function LatLonGrid(;size, lat, lon, z, if isa(z, Tuple{<:Number, <:Number}) z₁, z₂ = z - z = collect(range(z₁, z₂, length = Nz+1)) + z = collect(FT, range(z₁, z₂, length = Nz+1)) elseif isa(z, AbstractVector) z₁, z₂ = z[1], z[end] else @@ -95,27 +98,27 @@ function LatLonGrid(;size, lat, lon, z, Δx = (lon₂ - lon₁) / Nx Δy = (lat₂ - lat₁) / Ny - xF = range(lon₁ - Hx * Δx, lon₁ + (Nx + Hx - 1) * Δx, length = Nx + 2 * Hx) - yF = range(lat₁ - Hy * Δy, lat₁ + (Ny + Hy - 1) * Δy, length = Ny + 2 * Hy) + xF = collect(FT, range(lon₁ - Hx * Δx, lon₁ + (Nx + Hx - 1) * Δx, length = Nx + 2 * Hx)) + yF = collect(FT, range(lat₁ - Hy * Δy, lat₁ + (Ny + Hy - 1) * Δy, length = Ny + 2 * Hy)) - xC = range(lon₁ + (0.5 - Hx) * Δx, lon₁ + (Nx + Hx - 0.5) * Δx, length = Nx + 2 * Hx) - yC = range(lat₁ + (0.5 - Hy) * Δy, lat₁ + (Ny + Hy - 0.5) * Δy, length = Ny + 2 * Hy) + xC = collect(FT, range(lon₁ + (0.5 - Hx) * Δx, lon₁ + (Nx + Hx - 0.5) * Δx, length = Nx + 2 * Hx)) + yC = collect(FT, range(lat₁ + (0.5 - Hy) * Δy, lat₁ + (Ny + Hy - 0.5) * Δy, length = Ny + 2 * Hy)) # inclue halo points - zF = zeros(Nz+2Hz) - zC = zeros(Nz+2Hz) - dzF = zeros(Nz+2Hz) - dzC = zeros(Nz+2Hz) - dxC = zeros(Nx+2*Hx, Ny+2*Hy) - dyC = zeros(Nx+2*Hx, Ny+2*Hy) - dxF = zeros(Nx+2*Hx, Ny+2*Hy) - dyF = zeros(Nx+2*Hx, Ny+2*Hy) - Ax = zeros(Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) - Ay = zeros(Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) - Az = zeros(Nx+2*Hx, Ny+2*Hy) - Vol = zeros(Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) - - zF[1+Hz:Nz+Hz+1] .= Float64.(z) + zF = zeros(FT, Nz+2Hz) + zC = zeros(FT, Nz+2Hz) + dzF = zeros(FT, Nz+2Hz) + dzC = zeros(FT, Nz+2Hz) + dxC = zeros(FT, Nx+2*Hx, Ny+2*Hy) + dyC = zeros(FT, Nx+2*Hx, Ny+2*Hy) + dxF = zeros(FT, Nx+2*Hx, Ny+2*Hy) + dyF = zeros(FT, Nx+2*Hx, Ny+2*Hy) + Ax = zeros(FT, Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) + Ay = zeros(FT, Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) + Az = zeros(FT, Nx+2*Hx, Ny+2*Hy) + Vol = zeros(FT, Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) + + zF[1+Hz:Nz+Hz+1] .= FT.(z) zC[1+Hz:Nz+Hz] .= (zF[1+Hz:Nz+Hz] .+ zF[2+Hz:Nz+Hz+1]) ./ 2 dzF[1+Hz:Nz+Hz] .= zF[1+Hz:Nz+Hz] .- zF[2+Hz:Nz+Hz+1] dzC[1+Hz:Nz+Hz-1] .= zC[1+Hz:Nz+Hz-1] .- zC[2+Hz:Nz+Hz] @@ -161,16 +164,17 @@ function LatLonGrid(;size, lat, lon, z, end end - landmask = landmask_validation(landmask, Nx, Ny, Nz, Hx, Hy, Hz, TX, TY) + landmask = landmask_validation(landmask, Nx, Ny, Nz, Hx, Hy, Hz, FT, TX, TY) - return LatLonGrid{TX, TY, TZ, typeof(xF), typeof(zF), typeof(dxC), typeof(Vol)}( + return LatLonGrid{FT, TX, TY, TZ}( xC, yC, zC, xF, yF, zF, Δx, Δy, dxC, dyC, dzC, dxF, dyF, dzF, Ax, Ay, Az, Vol, Nx, Ny, Nz, Hx, Hy, Hz, landmask) end """ LoadLatLonGrid(;grid_info, size, lat, lon, - landmask = nothing, - halo=(2,2,2)) + FT = Float32, + landmask = nothing, + halo=(2,2,2)) Creats a `LatLonGrid` struct with `size = (Nx, Ny, Nz)` grid points. Keyword Arguments (Required) @@ -186,13 +190,15 @@ Keyword Arguments (Required) Keyword Arguments (Optional) ============================ +- `FT`: Floating point data type. Default: `Float32`. - `landmask` : a 3-dimentional array to indicate where the land is. - `halo` : A tuple of integers that specifies the size of the halo region of cells surrounding the physical interior for each direction. `halo` is a 3-tuple no matter for 3D, 2D, or 1D model. At least 2 halo points are needed for DST3FL advection scheme. """ -function LoadLatLonGrid(;grid_info, size, lat, lon, landmask = nothing, halo=(2,2,2)) +function LoadLatLonGrid(;grid_info, size, lat, lon, + FT = Float32, landmask = nothing, halo=(2,2,2)) Nx, Ny, Nz = size Hx, Hy, Hz = halo lat₁, lat₂ = lat @@ -208,40 +214,40 @@ function LoadLatLonGrid(;grid_info, size, lat, lon, landmask = nothing, halo=(2, Δx = (lon₂ - lon₁) / Nx Δy = (lat₂ - lat₁) / Ny - xF = range(lon₁ - Hx * Δx, lon₁ + (Nx + Hx - 1) * Δx, length = Nx + 2 * Hx) - yF = range(lat₁ - Hy * Δy, lat₁ + (Ny + Hy - 1) * Δy, length = Ny + 2 * Hy) - - xC = range(lon₁ + (0.5 - Hx) * Δx, lon₁ + (Nx + Hx - 0.5) * Δx, length = Nx + 2 * Hx) - yC = range(lat₁ + (0.5 - Hy) * Δy, lat₁ + (Ny + Hy - 0.5) * Δy, length = Ny + 2 * Hy) - - zF = zeros(Nz+2Hz) - zC = zeros(Nz+2Hz) - dzF = zeros(Nz+2Hz) - dzC = zeros(Nz+2Hz) - dxC = zeros(Nx+2*Hx, Ny+2*Hy) - dyC = zeros(Nx+2*Hx, Ny+2*Hy) - dxF = zeros(Nx+2*Hx, Ny+2*Hy) - dyF = zeros(Nx+2*Hx, Ny+2*Hy) - Ax = zeros(Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) - Ay = zeros(Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) - Az = zeros(Nx+2*Hx, Ny+2*Hy) - Vol = zeros(Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) - hFW = ones(Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) - hFS = ones(Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) - hFC = ones(Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) - - zF[1+Hz:Nz+Hz+1] = Float64.(grid_info.RF) - zC[1+Hz:Nz+Hz] = Float64.(grid_info.RC) - dzF[1+Hz:Nz+Hz] = Float64.(grid_info.DRF) - dzC[1+Hz:Nz+Hz] = Float64.(grid_info.DRC); dzC[1+Hz] *= 2.0 # First laryer only has half of the grid - dxC[1+Hx:Nx+Hx, 1+Hy:Ny+Hy] = Float64.(grid_info.DXC) - dxF[1+Hx:Nx+Hx, 1+Hy:Ny+Hy] = Float64.(grid_info.DXG) - dyC[1+Hx:Nx+Hx, 1+Hy:Ny+Hy] = Float64.(grid_info.DYC) - dyF[1+Hx:Nx+Hx, 1+Hy:Ny+Hy] = Float64.(grid_info.DYG) - Az[1+Hx:Nx+Hx, 1+Hy:Ny+Hy] = Float64.(grid_info.RAC) - hFW[1+Hx:Nx+Hx, 1+Hy:Ny+Hy, 1+Hz:Nz+Hz] = Float64.(grid_info.hFacW) - hFS[1+Hx:Nx+Hx, 1+Hy:Ny+Hy, 1+Hz:Nz+Hz] = Float64.(grid_info.hFacS) - hFC[1+Hx:Nx+Hx, 1+Hy:Ny+Hy, 1+Hz:Nz+Hz] = Float64.(grid_info.hFacC) + xF = collect(FT, range(lon₁ - Hx * Δx, lon₁ + (Nx + Hx - 1) * Δx, length = Nx + 2 * Hx)) + yF = collect(FT, range(lat₁ - Hy * Δy, lat₁ + (Ny + Hy - 1) * Δy, length = Ny + 2 * Hy)) + + xC = collect(FT, range(lon₁ + (0.5 - Hx) * Δx, lon₁ + (Nx + Hx - 0.5) * Δx, length = Nx + 2 * Hx)) + yC = collect(FT, range(lat₁ + (0.5 - Hy) * Δy, lat₁ + (Ny + Hy - 0.5) * Δy, length = Ny + 2 * Hy)) + + zF = zeros(FT, Nz+2Hz) + zC = zeros(FT, Nz+2Hz) + dzF = zeros(FT, Nz+2Hz) + dzC = zeros(FT, Nz+2Hz) + dxC = zeros(FT, Nx+2*Hx, Ny+2*Hy) + dyC = zeros(FT, Nx+2*Hx, Ny+2*Hy) + dxF = zeros(FT, Nx+2*Hx, Ny+2*Hy) + dyF = zeros(FT, Nx+2*Hx, Ny+2*Hy) + Ax = zeros(FT, Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) + Ay = zeros(FT, Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) + Az = zeros(FT, Nx+2*Hx, Ny+2*Hy) + Vol = zeros(FT, Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) + hFW = ones(FT, Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) + hFS = ones(FT, Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) + hFC = ones(FT, Nx+2*Hx, Ny+2*Hy, Nz+2*Hz) + + zF[1+Hz:Nz+Hz+1] = FT.(grid_info.RF) + zC[1+Hz:Nz+Hz] = FT.(grid_info.RC) + dzF[1+Hz:Nz+Hz] = FT.(grid_info.DRF) + dzC[1+Hz:Nz+Hz] = FT.(grid_info.DRC); dzC[1+Hz] *= 2.0 # First laryer only has half of the grid + dxC[1+Hx:Nx+Hx, 1+Hy:Ny+Hy] = FT.(grid_info.DXC) + dxF[1+Hx:Nx+Hx, 1+Hy:Ny+Hy] = FT.(grid_info.DXG) + dyC[1+Hx:Nx+Hx, 1+Hy:Ny+Hy] = FT.(grid_info.DYC) + dyF[1+Hx:Nx+Hx, 1+Hy:Ny+Hy] = FT.(grid_info.DYG) + Az[1+Hx:Nx+Hx, 1+Hy:Ny+Hy] = FT.(grid_info.RAC) + hFW[1+Hx:Nx+Hx, 1+Hy:Ny+Hy, 1+Hz:Nz+Hz] = FT.(grid_info.hFacW) + hFS[1+Hx:Nx+Hx, 1+Hy:Ny+Hy, 1+Hz:Nz+Hz] = FT.(grid_info.hFacS) + hFC[1+Hx:Nx+Hx, 1+Hy:Ny+Hy, 1+Hz:Nz+Hz] = FT.(grid_info.hFacC) ##### fill halos @views @. dzF[1:Hz] = dzF[Hz+1] @@ -296,19 +302,19 @@ function LoadLatLonGrid(;grid_info, size, lat, lon, landmask = nothing, halo=(2, end end - landmask = landmask_validation(landmask, Nx, Ny, Nz, Hx, Hy, Hz, TX, TY) + landmask = landmask_validation(landmask, Nx, Ny, Nz, Hx, Hy, Hz, FT, TX, TY) - return LatLonGrid{TX, TY, TZ, typeof(xF), typeof(zF), typeof(dxC), typeof(Vol)}( + return LatLonGrid{FT, TX, TY, TZ}( xC, yC, zC, xF, yF, zF, Δx, Δy, dxC, dyC, dzC, dxF, dyF, dzF, Ax, Ay, Az, Vol, Nx, Ny, Nz, Hx, Hy, Hz, landmask) end -function show(io::IO, g::LatLonGrid{TX, TY, TZ}) where {TX, TY, TZ} +function show(io::IO, g::LatLonGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} xL, xR = g.xF[g.Hx+1], g.xF[g.Hx+1+g.Nx] yL, yR = g.yF[g.Hy+1], g.yF[g.Hy+1+g.Ny] zL, zR = g.zF[g.Hz+1], g.zF[g.Hz+1+g.Nz] dzF_min = minimum(g.dzF) dzF_max = maximum(g.dzF) - print(io, "LatLonGrid{$TX, $TY, $TZ}\n", + print(io, "LatLonGrid{$FT, $TX, $TY, $TZ}\n", "domain: x ∈ [$xL, $xR], y ∈ [$yL, $yR], z ∈ [$zL, $zR]\n", "topology (Tx, Ty, Tz): ", (TX, TY, TZ), '\n', "resolution (Nx, Ny, Nz): ", (g.Nx, g.Ny, g.Nz), '\n', @@ -316,6 +322,6 @@ function show(io::IO, g::LatLonGrid{TX, TY, TZ}) where {TX, TY, TZ} "grid spacing (Δx, Δy, Δz): ", g.Δx, ", ", g.Δy, ", [min=", dzF_min, ", max=", dzF_max,"])") end -function short_show(grid::LatLonGrid{TX, TY, TZ}) where {TX, TY, TZ} - return "LatLonGrid{$TX, $TY, $TZ}(Nx=$(grid.Nx), Ny=$(grid.Ny), Nz=$(grid.Nz))" +function short_show(grid::LatLonGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} + return "LatLonGrid{$FT, $TX, $TY, $TZ}(Nx=$(grid.Nx), Ny=$(grid.Ny), Nz=$(grid.Nz))" end diff --git a/src/Grids/rectilinear_grid.jl b/src/Grids/rectilinear_grid.jl index 89df259f..65939d4f 100644 --- a/src/Grids/rectilinear_grid.jl +++ b/src/Grids/rectilinear_grid.jl @@ -1,17 +1,17 @@ -struct RectilinearGrid{TX, TY, TZ, R, A1, A3} <: AbstractGrid{TX, TY, TZ} +struct RectilinearGrid{FT, TX, TY, TZ} <: AbstractGrid{FT, TX, TY, TZ} # corrdinates at cell centers, unit: meter - xC::R - yC::R - zC::A1 + xC::Vector{FT} + yC::Vector{FT} + zC::Vector{FT} # corrdinates at cell faces, unit: meter - xF::R - yF::R - zF::A1 + xF::Vector{FT} + yF::Vector{FT} + zF::Vector{FT} # grid spacing, unit: meter - Δx::Float64 - Δy::Float64 - dzC::A1 - dzF::A1 + Δx::FT + Δy::FT + dzC::Vector{FT} + dzF::Vector{FT} # number of grid points Nx::Int Ny::Int @@ -21,11 +21,12 @@ struct RectilinearGrid{TX, TY, TZ, R, A1, A3} <: AbstractGrid{TX, TY, TZ} Hy::Int Hz::Int # landmask to indicate where is the land - landmask::A3 + landmask::AbstractArray{FT,3} end """ RectilinearGrid(;size, x, y, z, + FT = Float32, topology = (Periodic, Periodic, Bounded), landmask = nothing, halo = (2, 2, 2)) @@ -44,6 +45,7 @@ Keyword Arguments (Required) Keyword Arguments (Optional) ============================ +- `FT`: Floating point data type. Default: `Float32`. - `topology` : A 3-tuple specifying the topology of the domain. The topology can be either Periodic or Bounded in each direction. - `landmask` : a 3-dimentional array to indicate where the land is. @@ -53,9 +55,10 @@ Keyword Arguments (Optional) At least 2 halo points are needed for DST3FL advection scheme. """ function RectilinearGrid(;size, x, y, z, - topology = (Periodic, Periodic, Bounded), - landmask = nothing, - halo = (2, 2, 2)) + FT = Float32, + topology = (Periodic, Periodic, Bounded), + landmask = nothing, + halo = (2, 2, 2)) Nx, Ny, Nz = size Hx, Hy, Hz = halo TX, TY, TZ = validate_topology(topology) @@ -69,15 +72,15 @@ function RectilinearGrid(;size, x, y, z, Δx = (x₂ - x₁) / Nx Δy = (y₂ - y₁) / Ny - xF = range(-Hx * Δx, (Nx + Hx - 1) * Δx, length = Nx + 2 * Hx) - yF = range(-Hy * Δy, (Ny + Hy - 1) * Δy, length = Ny + 2 * Hy) + xF = collect(FT, range(-Hx * Δx, (Nx + Hx - 1) * Δx, length = Nx + 2 * Hx)) + yF = collect(FT, range(-Hy * Δy, (Ny + Hy - 1) * Δy, length = Ny + 2 * Hy)) - xC = range((0.5 - Hx) * Δx, (Nx + Hx - 0.5) * Δx, length = Nx + 2 * Hx) - yC = range((0.5 - Hy) * Δy, (Ny + Hy - 0.5) * Δy, length = Ny + 2 * Hy) + xC = collect(FT, range((0.5 - Hx) * Δx, (Nx + Hx - 0.5) * Δx, length = Nx + 2 * Hx)) + yC = collect(FT, range((0.5 - Hy) * Δy, (Ny + Hy - 0.5) * Δy, length = Ny + 2 * Hy)) if isa(z, Tuple{<:Number, <:Number}) z₁, z₂ = z - z = collect(range(z₁, z₂, length = Nz+1)) + z = collect(FT, range(z₁, z₂, length = Nz+1)) elseif isa(z, AbstractVector) z₁, z₂ = z[1], z[end] else @@ -87,12 +90,12 @@ function RectilinearGrid(;size, x, y, z, @assert z₁ > z₂ @assert Base.length(z) == Nz + 1 - zF = zeros(Nz+2Hz) - zC = zeros(Nz+2Hz) - dzF = zeros(Nz+2Hz) - dzC = zeros(Nz+2Hz) + zF = zeros(FT, Nz+2Hz) + zC = zeros(FT, Nz+2Hz) + dzF = zeros(FT, Nz+2Hz) + dzC = zeros(FT, Nz+2Hz) - zF[1+Hz:Nz+Hz+1] .= Float64.(z) + zF[1+Hz:Nz+Hz+1] .= FT.(z) zC[1+Hz:Nz+Hz] .= (zF[1+Hz:Nz+Hz] .+ zF[2+Hz:Nz+Hz+1]) ./ 2 dzF[1+Hz:Nz+Hz] .= zF[1+Hz:Nz+Hz] .- zF[2+Hz:Nz+Hz+1] dzC[1+Hz:Nz+Hz-1] .= zC[1+Hz:Nz+Hz-1] .- zC[2+Hz:Nz+Hz] @@ -112,19 +115,19 @@ function RectilinearGrid(;size, x, y, z, zC[i] = zC[i+1] + dzF[1] end - landmask = landmask_validation(landmask, Nx, Ny, Nz, Hx, Hy, Hz, TX, TY) + landmask = landmask_validation(landmask, Nx, Ny, Nz, Hx, Hy, Hz, FT, TX, TY) - return RectilinearGrid{TX, TY, TZ, typeof(xF), typeof(zF), typeof(landmask)}( + return RectilinearGrid{FT, TX, TY, TZ}( xC, yC, zC, xF, yF, zF, Δx, Δy, dzC, dzF, Nx, Ny, Nz, Hx, Hy, Hz, landmask) end -function show(io::IO, g::RectilinearGrid{TX, TY, TZ}) where {TX, TY, TZ} +function show(io::IO, g::RectilinearGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} xL, xR = g.xF[g.Hx+1], g.xF[g.Hx+1+g.Nx] yL, yR = g.yF[g.Hy+1], g.yF[g.Hy+1+g.Ny] zL, zR = g.zF[g.Hz+1], g.zF[g.Hz+1+g.Nz] dzF_min = minimum(g.dzF) dzF_max = maximum(g.dzF) - print(io, "RegularRectilinearGrid{$TX, $TY, $TZ}\n", + print(io, "RegularRectilinearGrid{$FT, $TX, $TY, $TZ}\n", "domain: x ∈ [$xL, $xR], y ∈ [$yL, $yR], z ∈ [$zL, $zR]\n", "topology (Tx, Ty, Tz): ", (TX, TY, TZ), '\n', "resolution (Nx, Ny, Nz): ", (g.Nx, g.Ny, g.Nz), '\n', @@ -132,8 +135,8 @@ function show(io::IO, g::RectilinearGrid{TX, TY, TZ}) where {TX, TY, TZ} "grid spacing (Δx, Δy, Δz): ", g.Δx, ", ", g.Δy, ", [min=", dzF_min, ", max=", dzF_max,"])") end -function short_show(grid::RectilinearGrid{TX, TY, TZ}) where {TX, TY, TZ} - return "RegularRectilinearGrid{$TX, $TY, $TZ}(Nx=$(grid.Nx), Ny=$(grid.Ny), Nz=$(grid.Nz))" +function short_show(grid::RectilinearGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} + return "RegularRectilinearGrid{$FT, $TX, $TY, $TZ}(Nx=$(grid.Nx), Ny=$(grid.Ny), Nz=$(grid.Nz))" end diff --git a/src/Grids/utils.jl b/src/Grids/utils.jl index 44b551b1..35d1c7f1 100644 --- a/src/Grids/utils.jl +++ b/src/Grids/utils.jl @@ -1,7 +1,7 @@ ##### ##### replace the storage place of the grid information based on the architecture ##### -function replace_grid_storage(arch::Architecture, grid::LatLonGrid{TX, TY, TZ}) where {TX, TY, TZ} +function replace_grid_storage(arch::Architecture, grid::LatLonGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} zF = grid.zF |> array_type(arch) zC = grid.zC |> array_type(arch) dxC = grid.dxC |> array_type(arch) @@ -16,18 +16,18 @@ function replace_grid_storage(arch::Architecture, grid::LatLonGrid{TX, TY, TZ}) Vol = grid.Vol |> array_type(arch) landmask = grid.landmask |> array_type(arch) - return LatLonGrid{TX, TY, TZ, typeof(grid.xF), typeof(zF), typeof(dxC), typeof(Vol)}( + return LatLonGrid{FT, TX, TY, TZ}( grid.xC, grid.yC, zC, grid.xF, grid.yF, zF, grid.Δx, grid.Δy, dxC, dyC, dzC, dxF, dyF, dzF, Ax, Ay, Az, Vol, grid.Nx, grid.Ny, grid.Nz, grid.Hx, grid.Hy, grid.Hz, landmask) end -function replace_grid_storage(arch::Architecture, grid::RectilinearGrid{TX, TY, TZ}) where {TX, TY, TZ} +function replace_grid_storage(arch::Architecture, grid::RectilinearGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} zF = grid.zF |> array_type(arch) zC = grid.zC |> array_type(arch) dzC = grid.dzC |> array_type(arch) dzF = grid.dzF |> array_type(arch) landmask = grid.landmask |> array_type(arch) - return RectilinearGrid{TX, TY, TZ, typeof(grid.xF), typeof(zF), typeof(landmask)}( + return RectilinearGrid{FT, TX, TY, TZ}( grid.xC, grid.yC, zC, grid.xF, grid.yF, zF, grid.Δx, grid.Δy, dzC, dzF, grid.Nx, grid.Ny, grid.Nz, grid.Hx, grid.Hy, grid.Hz, landmask) end @@ -36,8 +36,8 @@ end ##### adapt the grid struct to GPU ##### -Adapt.adapt_structure(to, grid::RectilinearGrid{TX, TY, TZ}) where {TX, TY, TZ} = - RectilinearGrid{TX, TY, TZ, typeof(grid.xF), typeof(Adapt.adapt(to, grid.zF)), typeof(Adapt.adapt(to, grid.landmask))}( +Adapt.adapt_structure(to, grid::RectilinearGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} = + RectilinearGrid{FT, TX, TY, TZ}( grid.xC, grid.yC, Adapt.adapt(to, grid.zC), grid.xF, grid.yF, @@ -49,9 +49,8 @@ Adapt.adapt_structure(to, grid::RectilinearGrid{TX, TY, TZ}) where {TX, TY, TZ} grid.Hx, grid.Hy, grid.Hz, Adapt.adapt(to, grid.landmask)) -Adapt.adapt_structure(to, grid::LatLonGrid{TX, TY, TZ}) where {TX, TY, TZ} = - LatLonGrid{TX, TY, TZ, typeof(grid.xF), typeof(Adapt.adapt(to, grid.zF)), - typeof(Adapt.adapt(to, grid.dxC)), typeof(Adapt.adapt(to, grid.Vol))}( +Adapt.adapt_structure(to, grid::LatLonGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} = + LatLonGrid{FT, TX, TY, TZ}( grid.xC, grid.yC, Adapt.adapt(to, grid.zC), grid.xF, grid.yF, @@ -74,9 +73,9 @@ Adapt.adapt_structure(to, grid::LatLonGrid{TX, TY, TZ}) where {TX, TY, TZ} = ##### ##### validate the land mask ##### -function landmask_validation(landmask, Nx, Ny, Nz, Hx, Hy, Hz, TX, TY) - if landmask == nothing - landmask = ones(Nx, Ny, Nz) +function landmask_validation(landmask, Nx, Ny, Nz, Hx, Hy, Hz, FT, TX, TY) + if isnothing(landmask) + landmask = ones(FT, Nx, Ny, Nz) else if Base.size(landmask) ≠ (Nx, Ny, Nz) throw(ArgumentError("landmask: grid mismatch, size(landmask) must equal to (grid.Nx, grid.Ny, grid.Nz).")) diff --git a/src/Model/models.jl b/src/Model/models.jl index e611710b..185e1dd4 100644 --- a/src/Model/models.jl +++ b/src/Model/models.jl @@ -1,9 +1,10 @@ mutable struct PlanktonModel arch::Architecture # architecture on which models will run - t::Float64 # time in second - iteration::Int64 # model interation - individuals::individuals # initial individuals generated by `setup_agents` - nutrients::NamedTuple # initial nutrient fields + FT::DataType # floating point data type + t::AbstractFloat # time in second + iteration::Int # model interation + individuals::individuals # individuals + nutrients::NamedTuple # nutrient fields grid::AbstractGrid # grid information bgc_params::Dict # biogeochemical parameter set timestepper::timestepper # operating Tuples and arrays for timestep @@ -12,6 +13,7 @@ end """ PlanktonModel(arch::Architecture, grid::AbstractGrid; + FT = Float32, mode = QuotaMode(), N_species = 1, N_individual = [1024], @@ -31,6 +33,7 @@ Keyword Arguments (Required) Keyword Arguments (Optional) ============================ +- `FT`: Floating point data type. Default: `Float32`. - `mode` : Phytoplankton physiology mode, choose among CarbonMode(), QuotaMode(), or MacroMolecularMode(). - `N_species` : Number of species. - `N_individual` : Number of individuals per species, should be a vector with `N_species` elements. @@ -46,6 +49,7 @@ Keyword Arguments (Optional) - `t` : Model time, start from 0 by default, in second. """ function PlanktonModel(arch::Architecture, grid::AbstractGrid; + FT = Float32, mode = QuotaMode(), N_species::Int64 = 1, N_individual::Vector{Int64} = [1024], @@ -84,15 +88,15 @@ function PlanktonModel(arch::Architecture, grid::AbstractGrid; end end - inds = generate_individuals(phyt_params_final, arch, N_species, N_individual, max_individuals, grid_d, mode) + inds = generate_individuals(phyt_params_final, arch, N_species, N_individual, max_individuals, FT, grid_d, mode) - nutrients = generate_nutrients(arch, grid_d, nut_initial) + nutrients = generate_nutrients(arch, grid_d, nut_initial, FT) - ts = timestepper(arch, grid_d, max_individuals) + ts = timestepper(arch, FT, grid_d, max_individuals) iteration = 0 - model = PlanktonModel(arch, t, iteration, inds, nutrients, grid_d, bgc_params_final, ts, mode) + model = PlanktonModel(arch, FT, t, iteration, inds, nutrients, grid_d, bgc_params_final, ts, mode) return model end @@ -102,6 +106,7 @@ function show(io::IO, model::PlanktonModel) N = Int(dot(model.individuals.phytos.sp1.data.ac,model.individuals.phytos.sp1.data.ac)) cap = length(model.individuals.phytos.sp1.data.ac) print(io, "PlanktonModel:\n", + "├── floating point data type: $(model.FT)\n", "├── grid: $(short_show(model.grid))\n", "├── $(model.mode) is selected for phytoplankton physiology\n", "├── individuals: $(Nsp) phytoplankton species with $(N) individuals for each species\n", diff --git a/src/Model/timestepper.jl b/src/Model/timestepper.jl index 1d09cd3d..9b3fbd50 100644 --- a/src/Model/timestepper.jl +++ b/src/Model/timestepper.jl @@ -15,33 +15,34 @@ mutable struct timestepper nuts::AbstractArray # a StructArray of nutrients of each individual end -function timestepper(arch::Architecture, g::AbstractGrid, maxN) - vel₀ = (u = Field(arch, g), v = Field(arch, g), w = Field(arch, g)) - vel½ = (u = Field(arch, g), v = Field(arch, g), w = Field(arch, g)) - vel₁ = (u = Field(arch, g), v = Field(arch, g), w = Field(arch, g)) +function timestepper(arch::Architecture, FT::DataType, g::AbstractGrid, maxN) + vel₀ = (u = Field(arch, g, FT), v = Field(arch, g, FT), w = Field(arch, g, FT)) + vel½ = (u = Field(arch, g, FT), v = Field(arch, g, FT), w = Field(arch, g, FT)) + vel₁ = (u = Field(arch, g, FT), v = Field(arch, g, FT), w = Field(arch, g, FT)) - Gcs = nutrients_init(arch, g) - nut_temp = nutrients_init(arch, g) - plk = nutrients_init(arch, g) + Gcs = nutrients_init(arch, g, FT) + nut_temp = nutrients_init(arch, g, FT) + plk = nutrients_init(arch, g, FT) - par = zeros(g.Nx+g.Hx*2, g.Ny+g.Hy*2, g.Nz+g.Hz*2) |> array_type(arch) - Chl = zeros(g.Nx+g.Hx*2, g.Ny+g.Hy*2, g.Nz+g.Hz*2) |> array_type(arch) - pop = zeros(g.Nx+g.Hx*2, g.Ny+g.Hy*2, g.Nz+g.Hz*2) |> array_type(arch) + par = zeros(FT, g.Nx+g.Hx*2, g.Ny+g.Hy*2, g.Nz+g.Hz*2) |> array_type(arch) + Chl = zeros(FT, g.Nx+g.Hx*2, g.Ny+g.Hy*2, g.Nz+g.Hz*2) |> array_type(arch) + pop = zeros(FT, g.Nx+g.Hx*2, g.Ny+g.Hy*2, g.Nz+g.Hz*2) |> array_type(arch) - temp = zeros(g.Nx+g.Hx*2, g.Ny+g.Hy*2, g.Nz+g.Hz*2) |> array_type(arch) - PARF = zeros(g.Nx, g.Ny) |> array_type(arch) + temp = zeros(FT, g.Nx+g.Hx*2, g.Ny+g.Hy*2, g.Nz+g.Hz*2) |> array_type(arch) + PARF = zeros(FT, g.Nx, g.Ny) |> array_type(arch) - rnd = StructArray(x = zeros(maxN), y = zeros(maxN), z = zeros(maxN)) + rnd = StructArray(x = zeros(FT, maxN), y = zeros(FT, maxN), z = zeros(FT, maxN)) rnd_d = replace_storage(array_type(arch), rnd) - velos = StructArray(x = zeros(maxN), y = zeros(maxN), z = zeros(maxN), - u1 = zeros(maxN), v1 = zeros(maxN), w1 = zeros(maxN), - u2 = zeros(maxN), v2 = zeros(maxN), w2 = zeros(maxN), + velos = StructArray(x = zeros(FT, maxN), y = zeros(FT, maxN), z = zeros(FT, maxN), + u1 = zeros(FT, maxN), v1 = zeros(FT, maxN), w1 = zeros(FT, maxN), + u2 = zeros(FT, maxN), v2 = zeros(FT, maxN), w2 = zeros(FT, maxN), ) velos_d = replace_storage(array_type(arch), velos) - nuts = StructArray(NH4 = zeros(maxN), NO3 = zeros(maxN), PO4 = zeros(maxN), DOC = zeros(maxN), - par = zeros(maxN), T = zeros(maxN), pop = zeros(maxN)) + nuts = StructArray(NH4 = zeros(FT, maxN), NO3 = zeros(FT, maxN), PO4 = zeros(FT, maxN), + DOC = zeros(FT, maxN), + par = zeros(FT, maxN), T = zeros(FT, maxN), pop = zeros(FT, maxN)) nuts_d = replace_storage(array_type(arch), nuts) ts = timestepper(Gcs, nut_temp, vel₀, vel½, vel₁, PARF, temp, plk, par, Chl, pop, rnd_d, velos_d, nuts_d) diff --git a/src/Plankton/Advection/plankton_advection_kernels.jl b/src/Plankton/Advection/plankton_advection_kernels.jl index cb851c01..bd111ff5 100644 --- a/src/Plankton/Advection/plankton_advection_kernels.jl +++ b/src/Plankton/Advection/plankton_advection_kernels.jl @@ -10,7 +10,7 @@ end return x end -@kernel function particle_boundaries_kernel!(plank, ac, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ} +@kernel function particle_boundaries_kernel!(plank, ac, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} i = @index(Global) @inbounds plank.x[i] = particle_boundary_condition(plank.x[i], 0, g.Nx, TX()) * ac[i] @inbounds plank.y[i] = particle_boundary_condition(plank.y[i], 0, g.Ny, TY()) * ac[i] diff --git a/src/Plankton/CarbonMode/plankton_generation.jl b/src/Plankton/CarbonMode/plankton_generation.jl index ba83e44a..de7ce12b 100644 --- a/src/Plankton/CarbonMode/plankton_generation.jl +++ b/src/Plankton/CarbonMode/plankton_generation.jl @@ -1,13 +1,13 @@ -function construct_plankton(arch::Architecture, sp::Int64, params::Dict, maxN) - rawdata = StructArray(x = zeros(maxN), y = zeros(maxN), z = zeros(maxN), +function construct_plankton(arch::Architecture, sp::Int, params::Dict, maxN::Int, FT::DataType) + rawdata = StructArray(x = zeros(FT, maxN), y = zeros(FT, maxN), z = zeros(FT, maxN), xi = zeros(Int,maxN), yi = zeros(Int,maxN), zi = zeros(Int,maxN), - iS = zeros(maxN), Sz = zeros(maxN), - Bm = zeros(maxN), Bd = zeros(maxN), Chl = zeros(maxN), - gen = zeros(maxN), age = zeros(maxN), - ac = zeros(maxN), idx = zeros(maxN), - PS = zeros(maxN), BS = zeros(maxN), RS = zeros(maxN), - TD = zeros(maxN), RP = zeros(maxN), - graz= zeros(maxN), mort= zeros(maxN), dvid= zeros(maxN) + iS = zeros(FT, maxN), Sz = zeros(FT, maxN), + Bm = zeros(FT, maxN), Bd = zeros(FT, maxN), Chl = zeros(FT, maxN), + gen = zeros(FT, maxN), age = zeros(FT, maxN), + ac = zeros(FT, maxN), idx = zeros(FT, maxN), + PS = zeros(FT, maxN), BS = zeros(FT, maxN), RS = zeros(FT, maxN), + TD = zeros(FT, maxN), RP = zeros(FT, maxN), + graz= zeros(FT, maxN), mort= zeros(FT, maxN), dvid= zeros(FT, maxN) ) data = replace_storage(array_type(arch), rawdata) @@ -27,7 +27,7 @@ function construct_plankton(arch::Architecture, sp::Int64, params::Dict, maxN) return plankton(data, p) end -function generate_plankton!(plank, N::Int64, g::AbstractGrid, arch::Architecture) +function generate_plankton!(plank, N::Int, g::AbstractGrid, arch::Architecture) mean = plank.p.mean var = plank.p.var Cquota = plank.p.Cquota diff --git a/src/Plankton/MacroMolecularMode/plankton_generation.jl b/src/Plankton/MacroMolecularMode/plankton_generation.jl index 9daa62f8..5d2fe607 100644 --- a/src/Plankton/MacroMolecularMode/plankton_generation.jl +++ b/src/Plankton/MacroMolecularMode/plankton_generation.jl @@ -1,15 +1,15 @@ -function construct_plankton(arch::Architecture, sp::Int64, params::Dict, maxN) - rawdata = StructArray(x = zeros(maxN), y = zeros(maxN), z = zeros(maxN), - xi = zeros(Int,maxN), yi = zeros(Int,maxN), zi = zeros(Int,maxN), - CH = zeros(maxN), NST = zeros(maxN), PST = zeros(maxN), - PRO = zeros(maxN), DNA = zeros(maxN), RNA = zeros(maxN), - Chl = zeros(maxN), gen = zeros(maxN), age = zeros(maxN), - ac = zeros(maxN), idx = zeros(maxN), - PS = zeros(maxN), VDOC = zeros(maxN), VNH4 = zeros(maxN), - VNO3 = zeros(maxN), VPO4 = zeros(maxN), ρChl = zeros(maxN), - resp = zeros(maxN), S_PRO= zeros(maxN), S_DNA= zeros(maxN), - S_RNA= zeros(maxN), exu = zeros(maxN), - graz = zeros(maxN), mort = zeros(maxN), dvid = zeros(maxN) +function construct_plankton(arch::Architecture, sp::Int, params::Dict, maxN::Int, FT::DataType) + rawdata = StructArray(x = zeros(FT, maxN), y = zeros(FT, maxN), z = zeros(FT, maxN), + xi = zeros(Int,maxN), yi = zeros(Int,maxN), zi = zeros(Int,maxN), + CH = zeros(FT, maxN), NST = zeros(FT, maxN), PST = zeros(FT, maxN), + PRO = zeros(FT, maxN), DNA = zeros(FT, maxN), RNA = zeros(FT, maxN), + Chl = zeros(FT, maxN), gen = zeros(FT, maxN), age = zeros(FT, maxN), + ac = zeros(FT, maxN), idx = zeros(FT, maxN), + PS = zeros(FT, maxN), VDOC = zeros(FT, maxN), VNH4 = zeros(FT, maxN), + VNO3 = zeros(FT, maxN), VPO4 = zeros(FT, maxN), ρChl = zeros(FT, maxN), + resp = zeros(FT, maxN), S_PRO= zeros(FT, maxN), S_DNA= zeros(FT, maxN), + S_RNA= zeros(FT, maxN), exu = zeros(FT, maxN), + graz = zeros(FT, maxN), mort = zeros(FT, maxN), dvid = zeros(FT, maxN) ) data = replace_storage(array_type(arch), rawdata) @@ -35,7 +35,7 @@ function construct_plankton(arch::Architecture, sp::Int64, params::Dict, maxN) return plankton(data, p) end -function generate_plankton!(plank, N::Int64, g::AbstractGrid, arch::Architecture) +function generate_plankton!(plank, N::Int, g::AbstractGrid, arch::Architecture) mean = plank.p.mean var = plank.p.var C_DNA = plank.p.C_DNA diff --git a/src/Plankton/Plankton.jl b/src/Plankton/Plankton.jl index 47752b3f..ab0c025b 100644 --- a/src/Plankton/Plankton.jl +++ b/src/Plankton/Plankton.jl @@ -21,7 +21,7 @@ using PlanktonIndividuals: AbstractMode, CarbonMode, QuotaMode, MacroMolecularMo ##### ##### generate individuals of multiple species ##### -function generate_individuals(params::Dict, arch::Architecture, Nsp, N, maxN, g::AbstractGrid, mode::AbstractMode) +function generate_individuals(params::Dict, arch::Architecture, Nsp::Int, N::Vector{Int}, maxN::Int, FT::DataType, g::AbstractGrid, mode::AbstractMode) plank_names = Symbol[] plank_data=[] @@ -31,7 +31,7 @@ function generate_individuals(params::Dict, arch::Architecture, Nsp, N, maxN, g: for i in 1:Nsp name = Symbol("sp"*string(i)) - plank = construct_plankton(arch, i, params, maxN, mode::AbstractMode) + plank = construct_plankton(arch, i, params, maxN, FT, mode::AbstractMode) generate_plankton!(plank, N[i], g, arch, mode) push!(plank_names, name) push!(plank_data, plank) @@ -54,14 +54,14 @@ import .MacroMolecular ##### ##### some workarounds for function names ##### -construct_plankton(arch::Architecture, sp::Int64, params::Dict, maxN, mode::MacroMolecularMode) = - MacroMolecular.construct_plankton(arch::Architecture, sp::Int64, params::Dict, maxN) +construct_plankton(arch::Architecture, sp::Int, params::Dict, maxN::Int, FT::DataType, mode::MacroMolecularMode) = + MacroMolecular.construct_plankton(arch::Architecture, sp::Int, params::Dict, maxN::Int, FT::DataType) -construct_plankton(arch::Architecture, sp::Int64, params::Dict, maxN, mode::QuotaMode) = - Quota.construct_plankton(arch::Architecture, sp::Int64, params::Dict, maxN) +construct_plankton(arch::Architecture, sp::Int, params::Dict, maxN::Int, FT::DataType, mode::QuotaMode) = + Quota.construct_plankton(arch::Architecture, sp::Int, params::Dict, maxN::Int, FT::DataType) -construct_plankton(arch::Architecture, sp::Int64, params::Dict, maxN, mode::CarbonMode) = - Carbon.construct_plankton(arch::Architecture, sp::Int64, params::Dict, maxN) +construct_plankton(arch::Architecture, sp::Int, params::Dict, maxN::Int, FT::DataType, mode::CarbonMode) = + Carbon.construct_plankton(arch::Architecture, sp::Int, params::Dict, maxN::Int, FT::DataType) generate_plankton!(plank, N::Int64, g::AbstractGrid, arch::Architecture, mode::MacroMolecularMode) = MacroMolecular.generate_plankton!(plank, N::Int64, g::AbstractGrid, arch::Architecture) diff --git a/src/Plankton/QuotaMode/plankton_generation.jl b/src/Plankton/QuotaMode/plankton_generation.jl index d0031737..381824ee 100644 --- a/src/Plankton/QuotaMode/plankton_generation.jl +++ b/src/Plankton/QuotaMode/plankton_generation.jl @@ -1,14 +1,14 @@ -function construct_plankton(arch::Architecture, sp::Int64, params::Dict, maxN) - rawdata = StructArray(x = zeros(maxN), y = zeros(maxN), z = zeros(maxN), - xi = zeros(Int,maxN), yi = zeros(Int,maxN), zi = zeros(Int,maxN), - iS = zeros(maxN), Sz = zeros(maxN), Bm = zeros(maxN), - Cq = zeros(maxN), Nq = zeros(maxN), Pq = zeros(maxN), - Chl = zeros(maxN), gen = zeros(maxN), age = zeros(maxN), - ac = zeros(maxN), idx = zeros(maxN), - PS = zeros(maxN), VDOC = zeros(maxN), VNH4 = zeros(maxN), - VNO3 = zeros(maxN), VPO4 = zeros(maxN), ρChl = zeros(maxN), - resp = zeros(maxN), BS = zeros(maxN), exu = zeros(maxN), - graz = zeros(maxN), mort = zeros(maxN), dvid = zeros(maxN) +function construct_plankton(arch::Architecture, sp::Int, params::Dict, maxN::Int, FT::DataType) + rawdata = StructArray(x = zeros(FT, maxN), y = zeros(FT, maxN), z = zeros(FT, maxN), + xi = zeros(Int,maxN), yi = zeros(Int,maxN), zi = zeros(Int,maxN), + iS = zeros(FT, maxN), Sz = zeros(FT, maxN), Bm = zeros(FT, maxN), + Cq = zeros(FT, maxN), Nq = zeros(FT, maxN), Pq = zeros(FT, maxN), + Chl = zeros(FT, maxN), gen = zeros(FT, maxN), age = zeros(FT, maxN), + ac = zeros(FT, maxN), idx = zeros(FT, maxN), + PS = zeros(FT, maxN), VDOC = zeros(FT, maxN), VNH4 = zeros(FT, maxN), + VNO3 = zeros(FT, maxN), VPO4 = zeros(FT, maxN), ρChl = zeros(FT, maxN), + resp = zeros(FT, maxN), BS = zeros(FT, maxN), exu = zeros(FT, maxN), + graz = zeros(FT, maxN), mort = zeros(FT, maxN), dvid = zeros(FT, maxN) ) data = replace_storage(array_type(arch), rawdata) @@ -31,7 +31,7 @@ function construct_plankton(arch::Architecture, sp::Int64, params::Dict, maxN) return plankton(data, p) end -function generate_plankton!(plank, N::Int64, g::AbstractGrid, arch::Architecture) +function generate_plankton!(plank, N::Int, g::AbstractGrid, arch::Architecture) mean = plank.p.mean var = plank.p.var Cquota = plank.p.Cquota diff --git a/src/Simulation/simulations.jl b/src/Simulation/simulations.jl index 987ead90..e61e52af 100644 --- a/src/Simulation/simulations.jl +++ b/src/Simulation/simulations.jl @@ -1,18 +1,18 @@ mutable struct PlanktonInput - temp::AbstractArray{Float64,4} # temperature - PARF::AbstractArray{Float64,3} # PARF - vels::NamedTuple # velocity fields for nutrients and individuals - ΔT_vel::Float64 # time step of velocities provided - ΔT_PAR::Float64 # time step of surface PAR provided - ΔT_temp::Float64 # time step of temperature provided + temp::AbstractArray{AbstractFloat,4} # temperature + PARF::AbstractArray{AbstractFloat,3} # PARF + vels::NamedTuple # velocity fields + ΔT_vel::Float64 # time step of velocities provided + ΔT_PAR::Float64 # time step of surface PAR provided + ΔT_temp::Float64 # time step of temperature provided end mutable struct PlanktonSimulation model::PlanktonModel # Model object input::PlanktonInput # model input, temp, PAR, and velocities diags::Union{PlanktonDiagnostics,Nothing} # diagnostics - ΔT::Float64 # model time step - iterations::Int64 # run the simulation for this number of iterations + ΔT::Float64 # model time step + iterations::Int # run the simulation for this number of iterations output_writer::Union{PlanktonOutputWriter,Nothing} # Output writer end @@ -48,7 +48,7 @@ Keyword Arguments (Optional) - `ΔT_temp` : time step of temperature provided externally (in seconds). - `output_writer` : Output writer of the simulation generated by `PlanktonOutputWriter`. """ -function PlanktonSimulation(model::PlanktonModel; ΔT::Float64, iterations::Int64, +function PlanktonSimulation(model::PlanktonModel; ΔT::Float64, iterations::Int, PARF = default_PARF(model.grid, ΔT, iterations), temp = default_temperature(model.grid, ΔT, iterations), diags = nothing, @@ -63,7 +63,7 @@ function PlanktonSimulation(model::PlanktonModel; ΔT::Float64, iterations::Int6 validate_bcs(model.nutrients, model.grid, iterations) - if diags == nothing + if isnothing(diags) diags = PlanktonDiagnostics(model) end diff --git a/src/Simulation/utils.jl b/src/Simulation/utils.jl index 04e60e2a..ca18aa4f 100644 --- a/src/Simulation/utils.jl +++ b/src/Simulation/utils.jl @@ -1,8 +1,8 @@ """ - vel_copy!(vel::NamedTuple, u, v, w, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ} + vel_copy!(vel::NamedTuple, u, v, w, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} Copy external velocities into `PlanktonModel` """ -function vel_copy!(vel::NamedTuple, u, v, w, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ} +function vel_copy!(vel::NamedTuple, u, v, w, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} copy_interior_u!(vel.u.data, u, g, TX()) copy_interior_v!(vel.v.data, v, g, TY()) copy_interior_w!(vel.w.data, w, g, TZ()) @@ -33,7 +33,7 @@ end copyto!(view(c, g.Hx+1:g.Hx+g.Nx, g.Hy+1:g.Hy+g.Ny, g.Hz+1:g.Hz+g.Nz+1), t) end -function validate_temp(sim::PlanktonSimulation, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ} +function validate_temp(sim::PlanktonSimulation, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} temp_size = (g.Nx, g.Ny, g.Nz) validation = true @@ -55,7 +55,7 @@ function validate_temp(sim::PlanktonSimulation, g::AbstractGrid{TX, TY, TZ}) whe return validation end -function validate_PARF(sim::PlanktonSimulation, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ} +function validate_PARF(sim::PlanktonSimulation, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} PARF_size = (g.Nx, g.Ny) validation = true @@ -77,7 +77,7 @@ function validate_PARF(sim::PlanktonSimulation, g::AbstractGrid{TX, TY, TZ}) whe return validation end -function validate_velocity(sim::PlanktonSimulation, g::AbstractGrid{TX, TY, TZ}) where {TX, TY, TZ} +function validate_velocity(sim::PlanktonSimulation, g::AbstractGrid{FT, TX, TY, TZ}) where {FT, TX, TY, TZ} if sim.input.vels ≠ (;) u_size = (g.Nx, g.Ny, g.Nz) v_size = (g.Nx, g.Ny, g.Nz) diff --git a/test/field_test.jl b/test/field_test.jl index 8bc86f86..bc8461c0 100644 --- a/test/field_test.jl +++ b/test/field_test.jl @@ -11,7 +11,7 @@ function test_fields() @test nut.DIC.data == zeros(8,10,6) @test interior(nut.DIC.data, grid) == zeros(4,6,2) - nuts = generate_nutrients(CPU(), grid, default_nut_init()) + nuts = generate_nutrients(CPU(), grid, default_nut_init(), Float32) @test maximum(nuts.DIC.data) < 23.0 @test minimum(nuts.DIC.data) > 17.0 @@ -22,7 +22,7 @@ function test_fields() end function test_fill_halos() grid = RectilinearGrid(size = (4,6,2), x = (0,12), y = (0,12), z = (0,-8)) - nuts = generate_nutrients(CPU(), grid, default_nut_init()) + nuts = generate_nutrients(CPU(), grid, default_nut_init(), Float32) Nx,Ny,Nz = grid.Nx, grid.Ny, grid.Nz Hx,Hy,Hz = grid.Hx, grid.Hy, grid.Hz @@ -110,11 +110,11 @@ end function test_boundary_conditions() grid = RectilinearGrid(size = (4, 4, 4), x = (0,32), y = (0,32), z = (0,-32)) model = PlanktonModel(CPU(), grid; mode = CarbonMode()) - set_bc!(model, :DIC, :west, 0.1) + set_bc!(model; tracer = :DIC, pos = :west, bc_value = 0.1) @test model.nutrients.DIC.bc.west == 0.1 - set_bc!(model, :DIC, :west, ones(4,4)) + set_bc!(model; tracer = :DIC, pos = :west, bc_value = ones(4,4)) @test model.nutrients.DIC.bc.west == ones(4,4) - set_bc!(model, :DIC, :west, ones(4,4,10)) + set_bc!(model; tracer = :DIC, pos = :west, bc_value = ones(4,4,10)) @test model.nutrients.DIC.bc.west == ones(4,4,10) Gcs = nutrients_init(CPU(), grid) diff --git a/test/output_test.jl b/test/output_test.jl index baf9f25b..33ccb64b 100644 --- a/test/output_test.jl +++ b/test/output_test.jl @@ -27,12 +27,12 @@ function test_output() ds = jldopen("result/diags_part1.jld2") time_length = length(keys(ds["timeseries/t"])) - @test time_length == 2 + @test time_length > 1 close(ds) ds1 = jldopen("result/plankton_part1.jld2") time_length1 = length(keys(ds1["timeseries/t"])) - @test time_length1 == 1 + @test time_length1 > 1 close(ds1) rm("result", recursive=true)