Skip to content

Commit

Permalink
Merge branch 'main' into sk/ke_kernel_opt
Browse files Browse the repository at this point in the history
  • Loading branch information
sriharshakandala committed Jul 17, 2023
2 parents 05f9e5b + 0e7a5f9 commit da868a1
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 54 deletions.
8 changes: 8 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,14 @@ steps:
key: unit_field
command: "julia --color=yes --check-bounds=yes --project=test test/Fields/field.jl"

- label: "Unit: field cuda"
key: unit_field_cuda
command:
- "julia --project -e 'using CUDA; CUDA.versioninfo()'"
- "julia --color=yes --check-bounds=yes --project=test test/Fields/field.jl"
agents:
slurm_gpus: 1

- label: "Unit: fielddiffeq"
key: unit_fielddiffeq
command: "julia --color=yes --check-bounds=yes --project=test test/Fields/fielddiffeq.jl"
Expand Down
30 changes: 7 additions & 23 deletions examples/hybrid/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,27 +99,11 @@ else
)
end
p = get_cache(ᶜlocal_geometry, ᶠlocal_geometry, Y, dt, upwinding_mode)
if ode_algorithm <: Union{
OrdinaryDiffEq.OrdinaryDiffEqImplicitAlgorithm,
OrdinaryDiffEq.OrdinaryDiffEqAdaptiveImplicitAlgorithm,
}
use_transform = !(ode_algorithm in (Rosenbrock23, Rosenbrock32))
W = SchurComplementW(Y, use_transform, jacobian_flags, test_implicit_solver)
jac_kwargs =
use_transform ? (; jac_prototype = W, Wfact_t = Wfact!) :
(; jac_prototype = W, Wfact = Wfact!)

alg_kwargs = (; linsolve = linsolve!)
if ode_algorithm <: Union{
OrdinaryDiffEq.OrdinaryDiffEqNewtonAlgorithm,
OrdinaryDiffEq.OrdinaryDiffEqNewtonAdaptiveAlgorithm,
}
alg_kwargs =
(; alg_kwargs..., nlsolve = NLNewton(; max_iter = max_newton_iters))
end
else
jac_kwargs = alg_kwargs = (;)
end

include("ode_config.jl")

ode_algo =
ode_configuration(FT; ode_name = string(ode_algorithm), max_newton_iters)

if haskey(ENV, "OUTPUT_DIR")
output_dir = ENV["OUTPUT_DIR"]
Expand Down Expand Up @@ -164,7 +148,7 @@ callback =
problem = SplitODEProblem(
ODEFunction(
implicit_tendency!;
jac_kwargs...,
jac_kwargs(ode_algo, Y, jacobian_flags)...,
tgrad = (∂Y∂t, Y, p, t) -> (∂Y∂t .= FT(0)),
),
remaining_tendency!,
Expand All @@ -174,7 +158,7 @@ problem = SplitODEProblem(
)
integrator = OrdinaryDiffEq.init(
problem,
ode_algorithm(; alg_kwargs...);
ode_algo;
saveat = dt_save_to_sol == 0 ? [] : dt_save_to_sol,
callback = callback,
dt = dt,
Expand Down
88 changes: 88 additions & 0 deletions examples/hybrid/ode_config.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import DiffEqBase
import OrdinaryDiffEq as ODE
import ClimaTimeSteppers as CTS

"""
    is_explicit_CTS_algo_type(alg_or_tableau)

Return `true` when the type `alg_or_tableau` is a subtype of `CTS.ERKAlgorithmName`.
"""
function is_explicit_CTS_algo_type(alg_or_tableau)
    return alg_or_tableau <: CTS.ERKAlgorithmName
end

"""
    is_imex_CTS_algo_type(alg_or_tableau)

Return `true` when the type `alg_or_tableau` is a subtype of `CTS.IMEXARKAlgorithmName`.
"""
function is_imex_CTS_algo_type(alg_or_tableau)
    return alg_or_tableau <: CTS.IMEXARKAlgorithmName
end

"""
    is_implicit_type(alg_or_tableau)

Return `true` when `alg_or_tableau` is `ODE.IMEXEuler`, a subtype of one of the two
OrdinaryDiffEq implicit-algorithm supertypes listed below, or an IMEX
ClimaTimeSteppers algorithm name (as decided by `is_imex_CTS_algo_type`).
"""
function is_implicit_type(alg_or_tableau)
    is_ode_implicit =
        alg_or_tableau <: Union{
            ODE.OrdinaryDiffEqImplicitAlgorithm,
            ODE.OrdinaryDiffEqAdaptiveImplicitAlgorithm,
        }
    # The CTS check only runs when the OrdinaryDiffEq check is false.
    return is_ode_implicit || is_imex_CTS_algo_type(alg_or_tableau)
end
# `ODE.IMEXEuler` is matched through `typeof`, so it gets a dedicated method.
function is_implicit_type(::typeof(ODE.IMEXEuler))
    return true
end

"""
    is_ordinary_diffeq_newton(alg_or_tableau)

Return `true` when `alg_or_tableau` is `ODE.IMEXEuler` or a subtype of one of the two
OrdinaryDiffEq Newton-algorithm supertypes listed below.
"""
function is_ordinary_diffeq_newton(alg_or_tableau)
    return alg_or_tableau <: Union{
        ODE.OrdinaryDiffEqNewtonAlgorithm,
        ODE.OrdinaryDiffEqNewtonAdaptiveAlgorithm,
    }
end
# `ODE.IMEXEuler` is matched through `typeof`, so it gets a dedicated method.
function is_ordinary_diffeq_newton(::typeof(ODE.IMEXEuler))
    return true
end

"""
    is_imex_CTS_algo(ode_algo)

Return `true` for a `CTS.IMEXAlgorithm` instance and `false` for any other
`DiffEqBase.AbstractODEAlgorithm` instance.
"""
function is_imex_CTS_algo(::CTS.IMEXAlgorithm)
    return true
end
function is_imex_CTS_algo(::DiffEqBase.AbstractODEAlgorithm)
    return false
end

"""
    is_implicit(ode_algo)

Return `true` for instances of the two OrdinaryDiffEq implicit-algorithm supertypes
listed below; for anything else, defer to `is_imex_CTS_algo`.
"""
function is_implicit(
    ::Union{
        ODE.OrdinaryDiffEqImplicitAlgorithm,
        ODE.OrdinaryDiffEqAdaptiveImplicitAlgorithm,
    },
)
    return true
end
function is_implicit(ode_algo)
    return is_imex_CTS_algo(ode_algo)
end

"""
    is_rosenbrock(ode_algo)

Return `true` for `ODE.Rosenbrock23` and `ODE.Rosenbrock32` instances and `false` for
any other `DiffEqBase.AbstractODEAlgorithm` instance.
"""
function is_rosenbrock(::Union{ODE.Rosenbrock23, ODE.Rosenbrock32})
    return true
end
function is_rosenbrock(::DiffEqBase.AbstractODEAlgorithm)
    return false
end
"""
    use_transform(ode_algo)

Return `true` unless `ode_algo` is an IMEX ClimaTimeSteppers algorithm or one of the
Rosenbrock algorithms recognized by `is_rosenbrock`.
"""
function use_transform(ode_algo)
    # `is_rosenbrock` is only consulted when the algorithm is not a CTS IMEX one.
    return !is_imex_CTS_algo(ode_algo) && !is_rosenbrock(ode_algo)
end

"""
    jac_kwargs(ode_algo, Y, jacobi_flags)

Return the Jacobian-related keyword arguments for `ode_algo` as a `NamedTuple`.

For algorithms that are not implicit (per `is_implicit`) the result is empty. Otherwise
it holds `jac_prototype` (a `SchurComplementW` built from `Y` and `jacobi_flags`) together
with `Wfact!` under the key `Wfact_t` when `use_transform(ode_algo)` is `true`, or under
`Wfact` when it is `false`.
"""
function jac_kwargs(ode_algo, Y, jacobi_flags)
    # Nothing to configure for algorithms without an implicit part.
    is_implicit(ode_algo) || return NamedTuple()

    transform = use_transform(ode_algo)
    W = SchurComplementW(Y, transform, jacobi_flags)
    return transform ? (; jac_prototype = W, Wfact_t = Wfact!) :
           (; jac_prototype = W, Wfact = Wfact!)
end

"""
    ode_configuration(::Type{FT}; ode_name, max_newton_iters = nothing)

Build the time-stepping algorithm object named by `ode_name`.

`ode_name` may be module-qualified (e.g. `"OrdinaryDiffEq.Rosenbrock23"`); only the part
after the last `.` is used. The name is looked up in `ODE` first and in `CTS` otherwise.
`FT` is accepted for interface purposes and is not used in the body.

# Keywords
- `ode_name`: name of the algorithm or tableau. Required, despite the `nothing` default.
- `max_newton_iters`: iteration cap forwarded to `ODE.NLNewton` or `CTS.NewtonsMethod`
  for the algorithms that are configured with a Newton solver.

# Throws
- `ArgumentError` if `ode_name` is `nothing`.
- `ErrorException` if an OrdinaryDiffEq Newton algorithm is requested with
  `max_newton_iters == 1`.
"""
function ode_configuration(
    ::Type{FT};
    ode_name::Union{String, Nothing} = nothing,
    max_newton_iters = nothing,
) where {FT}
    # Fail with a clear message instead of a `MethodError` from string handling below.
    isnothing(ode_name) && throw(
        ArgumentError("`ode_name` must be provided (e.g. \"Rosenbrock23\"), got `nothing`"),
    )
    # Strip any module qualification; a name without "." is returned unchanged.
    ode_sym = Symbol(last(split(ode_name, ".")))
    alg_or_tableau = if hasproperty(ODE, ode_sym)
        # NOTE(review): no `apply_limiter` flag is visible in this function; confirm that
        # this warning is still relevant to callers.
        @warn "apply_limiter flag is ignored for OrdinaryDiffEq algorithms"
        getproperty(ODE, ode_sym)
    else
        getproperty(CTS, ode_sym)
    end
    @info "Using ODE config: `$alg_or_tableau`"

    if is_explicit_CTS_algo_type(alg_or_tableau)
        return CTS.ExplicitAlgorithm(alg_or_tableau())
    elseif !is_implicit_type(alg_or_tableau)
        return alg_or_tableau()
    elseif is_ordinary_diffeq_newton(alg_or_tableau)
        if max_newton_iters == 1
            error("OrdinaryDiffEq requires at least 2 Newton iterations")
        end
        # κ acts like a relative tolerance; its default value in ODE is 0.01.
        # With exactly 2 iterations it is set to Inf instead.
        nlsolve = ODE.NLNewton(;
            κ = max_newton_iters == 2 ? Inf : 0.01,
            max_iter = max_newton_iters,
        )
        return alg_or_tableau(; linsolve = linsolve!, nlsolve)
    elseif is_imex_CTS_algo_type(alg_or_tableau)
        newtons_method = CTS.NewtonsMethod(; max_iters = max_newton_iters)
        return CTS.IMEXAlgorithm(alg_or_tableau(), newtons_method)
    else
        return alg_or_tableau(; linsolve = linsolve!)
    end
end
21 changes: 12 additions & 9 deletions examples/hybrid/tuning/mwe_tune_ke.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,18 @@ function mwe_compute_kinetic_energy()
# horizontal space
domain = Domains.SphereDomain(R)
horizontal_mesh = Meshes.EquiangularCubedSphere(domain, h_elem)
#=
horizontal_topology = Topologies.Topology2D(
context,
horizontal_mesh,
Topologies.spacefillingcurve(horizontal_mesh),
)
=#
horizontal_topology_cpu = Topologies.Topology2D(
context_cpu,
horizontal_mesh,
Topologies.spacefillingcurve(horizontal_mesh),
)
quad = Spaces.Quadratures.GLL{npoly + 1}()
#h_space = Spaces.SpectralElementSpace2D(horizontal_topology, quad)
h_space = Spaces.SpectralElementSpace2D(horizontal_topology, quad)
h_space_cpu = Spaces.SpectralElementSpace2D(horizontal_topology_cpu, quad)

# vertical space
Expand All @@ -110,20 +108,18 @@ function mwe_compute_kinetic_energy()
boundary_tags = (:bottom, :top),
)
z_mesh = Meshes.IntervalMesh(z_domain, nelems = z_elem)
#z_topology = Topologies.IntervalTopology(context, z_mesh)
z_topology = Topologies.IntervalTopology(context, z_mesh)
z_topology_cpu = Topologies.IntervalTopology(context_cpu, z_mesh)

#z_center_space = Spaces.CenterFiniteDifferenceSpace(z_topology)
z_center_space = Spaces.CenterFiniteDifferenceSpace(z_topology)
z_center_space_cpu = Spaces.CenterFiniteDifferenceSpace(z_topology_cpu)

#z_face_space = Spaces.FaceFiniteDifferenceSpace(z_topology)
z_face_space = Spaces.FaceFiniteDifferenceSpace(z_topology)
z_face_space_cpu = Spaces.FaceFiniteDifferenceSpace(z_topology_cpu)

#=
hv_center_space =
Spaces.ExtrudedFiniteDifferenceSpace(h_space, z_center_space)
hv_face_space = Spaces.FaceExtrudedFiniteDifferenceSpace(hv_center_space)
=#

hv_center_space_cpu =
Spaces.ExtrudedFiniteDifferenceSpace(h_space_cpu, z_center_space_cpu)
Expand All @@ -137,9 +133,16 @@ function mwe_compute_kinetic_energy()
uᵥ_cpu = face_initial_condition(ᶠlocal_geometry_cpu)
κ_cpu = init_scalar_field(hv_center_space_cpu)

# GPU
ᶜlocal_geometry = Fields.local_geometry_field(hv_center_space)
ᶠlocal_geometry = Fields.local_geometry_field(hv_face_space)
uₕ = center_initial_condition(ᶜlocal_geometry, R)
uᵥ = face_initial_condition(ᶠlocal_geometry)
κ = init_scalar_field(hv_center_space)


# compute kinetic energy
compute_kinetic_ca!(κ_cpu, uₕ_cpu, uᵥ_cpu)
compute_kinetic_ca!(κ, uₕ, uᵥ)

end

Expand Down
63 changes: 41 additions & 22 deletions test/Fields/field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ function spectral_space_2D(; n1 = 1, n2 = 1, Nij = 4)
x2boundary = (:south, :north),
)
mesh = Meshes.RectilinearMesh(domain, n1, n2)
device = ClimaComms.CPUSingleThreaded()
grid_topology =
Topologies.Topology2D(ClimaComms.SingletonCommsContext(), mesh)
Topologies.Topology2D(ClimaComms.SingletonCommsContext(device), mesh)

quad = Spaces.Quadratures.GLL{Nij}()
space = Spaces.SpectralElementSpace2D(grid_topology, quad)
Expand Down Expand Up @@ -101,11 +102,17 @@ function pow_n(f)
end
@testset "Broadcasting with ^n" begin
FT = Float32
for space in TU.all_spaces(FT)
device = ClimaComms.CPUSingleThreaded() # fill is broken on gpu
context = ClimaComms.SingletonCommsContext(device)
for space in TU.all_spaces(FT; context)
f = fill((; x = FT(1)), space)
pow_n(f) # Compile first
p_allocated = @allocated pow_n(f)
@test p_allocated == 0
if space isa Spaces.SpectralElementSpace1D
@test p_allocated == 0
else
@test p_allocated == 0 broken = (device isa ClimaComms.CUDADevice)
end
end
end

Expand All @@ -130,9 +137,11 @@ end

@testset "Broadcasting ifelse" begin
FT = Float32
device = ClimaComms.CPUSingleThreaded() # broken on gpu
context = ClimaComms.SingletonCommsContext(device)
for space in (
TU.CenterExtrudedFiniteDifferenceSpace(FT),
TU.ColumnCenterFiniteDifferenceSpace(FT),
TU.CenterExtrudedFiniteDifferenceSpace(FT; context),
TU.ColumnCenterFiniteDifferenceSpace(FT; context),
)
a = Fields.level(fill(FT(0), space), 1)
b = Fields.level(fill(FT(2), space), 1)
Expand Down Expand Up @@ -258,28 +267,32 @@ end
end

@testset "FieldVector array_type" begin
space = TU.PointSpace(Float32)
device = ClimaComms.device()
context = ClimaComms.SingletonCommsContext(device)
space = TU.PointSpace(Float32; context)
xcenters = Fields.coordinate_field(space).x
y = Fields.FieldVector(x = xcenters)
@test ClimaComms.array_type(y) == Array
@test ClimaComms.array_type(y) == ClimaComms.array_type(device)
y = Fields.FieldVector(x = xcenters, y = xcenters)
@test ClimaComms.array_type(y) == Array
@test ClimaComms.array_type(y) == ClimaComms.array_type(device)
end

@testset "FieldVector basetype replacement and deepcopy" begin
device = ClimaComms.CPUSingleThreaded() # constructing space_vijfh is broken
context = ClimaComms.SingletonCommsContext(device)
domain_z = Domains.IntervalDomain(
Geometry.ZPoint(-1.0) .. Geometry.ZPoint(1.0),
periodic = true,
)
mesh_z = Meshes.IntervalMesh(domain_z; nelems = 10)
topology_z = Topologies.IntervalTopology(mesh_z)
topology_z = Topologies.IntervalTopology(context, mesh_z)

domain_x = Domains.IntervalDomain(
Geometry.XPoint(-1.0) .. Geometry.XPoint(1.0),
periodic = true,
)
mesh_x = Meshes.IntervalMesh(domain_x; nelems = 10)
topology_x = Topologies.IntervalTopology(mesh_x)
topology_x = Topologies.IntervalTopology(context, mesh_x)

domain_xy = Domains.RectangleDomain(
Geometry.XPoint(-1.0) .. Geometry.XPoint(1.0),
Expand All @@ -288,8 +301,7 @@ end
x2periodic = true,
)
mesh_xy = Meshes.RectilinearMesh(domain_xy, 10, 10)
topology_xy =
Topologies.Topology2D(ClimaComms.SingletonCommsContext(), mesh_xy)
topology_xy = Topologies.Topology2D(context, mesh_xy)

quad = Spaces.Quadratures.GLL{4}()

Expand Down Expand Up @@ -403,7 +415,8 @@ end
end

@testset "PointField" begin
context = ClimaComms.SingletonCommsContext()
device = ClimaComms.CPUSingleThreaded() # a bunch of cuda pieces are broken
context = ClimaComms.SingletonCommsContext(device)
FT = Float64
coord = Geometry.XPoint(FT(π))
space = Spaces.PointSpace(context, coord)
Expand Down Expand Up @@ -547,8 +560,9 @@ Base.broadcastable(x::InferenceFoo) = Ref(x)
end
FT = Float64
foo = InferenceFoo(2.0)

for space in TU.all_spaces(FT)
device = ClimaComms.CPUSingleThreaded() # cuda fill is broken
context = ClimaComms.SingletonCommsContext(device)
for space in TU.all_spaces(FT; context)
Y = fill((; a = FT(0), b = FT(1)), space)
@test_throws ErrorException("type InferenceFoo has no field bingo") FieldFromNamedTupleBroken(
space,
Expand Down Expand Up @@ -608,26 +622,28 @@ end
)
local_geometry = Geometry.LocalGeometry(coord, FT(1.0), FT(1.0), at)
space = Spaces.PointSpace(context, local_geometry)
dz_computed = parent(Fields.Δz_field(space))
dz_computed = Array(parent(Fields.Δz_field(space)))
@test length(dz_computed) == 1
@test dz_computed[1] == expected_dz
end
end

@testset "scalar assignment" begin
device = ClimaComms.CPUSingleThreaded() # constructing space_vijfh is broken
context = ClimaComms.SingletonCommsContext(device)
domain_z = Domains.IntervalDomain(
Geometry.ZPoint(-1.0) .. Geometry.ZPoint(1.0),
periodic = true,
)
mesh_z = Meshes.IntervalMesh(domain_z; nelems = 10)
topology_z = Topologies.IntervalTopology(mesh_z)
topology_z = Topologies.IntervalTopology(context, mesh_z)

domain_x = Domains.IntervalDomain(
Geometry.XPoint(-1.0) .. Geometry.XPoint(1.0),
periodic = true,
)
mesh_x = Meshes.IntervalMesh(domain_x; nelems = 10)
topology_x = Topologies.IntervalTopology(mesh_x)
topology_x = Topologies.IntervalTopology(context, mesh_x)

domain_xy = Domains.RectangleDomain(
Geometry.XPoint(-1.0) .. Geometry.XPoint(1.0),
Expand All @@ -636,8 +652,7 @@ end
x2periodic = true,
)
mesh_xy = Meshes.RectilinearMesh(domain_xy, 10, 10)
topology_xy =
Topologies.Topology2D(ClimaComms.SingletonCommsContext(), mesh_xy)
topology_xy = Topologies.Topology2D(context, mesh_xy)

quad = Spaces.Quadratures.GLL{4}()

Expand Down Expand Up @@ -683,8 +698,10 @@ convergence_rate(err, Δh) =
col_copy = similar(y[Fields.ColumnIndex((1, 1), 1)])
return Fields.Field(Fields.field_values(col_copy), axes(col_copy))
end
device = ClimaComms.CPUSingleThreaded()
context = ClimaComms.SingletonCommsContext(device)
for zelem in (2^2, 2^3, 2^4, 2^5)
for space in TU.all_spaces(FT; zelem)
for space in TU.all_spaces(FT; zelem, context)
# Filter out spaces without z coordinates:
TU.has_z_coordinates(space) || continue
# Skip spaces incompatible with Fields.bycolumn:
Expand Down Expand Up @@ -729,7 +746,9 @@ end

@testset "Allocation tests for integrals" begin
FT = Float64
for space in TU.all_spaces(FT)
device = ClimaComms.CPUSingleThreaded()
context = ClimaComms.SingletonCommsContext(device)
for space in TU.all_spaces(FT; context)
# Filter out spaces without z coordinates:
TU.has_z_coordinates(space) || continue
Y = fill((; y = FT(1)), space)
Expand Down

0 comments on commit da868a1

Please sign in to comment.