Skip to content

Commit

Permalink
Improve inference in grid constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Aug 17, 2024
1 parent 2f26e9a commit 5cd7a71
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 20 deletions.
32 changes: 25 additions & 7 deletions src/Grids/spectralelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,32 @@ function SpectralElementGrid1D(
end
end

function _SpectralElementGrid1D(
_SpectralElementGrid1D(
topology::Topologies.IntervalTopology,
quadrature_style::Quadratures.QuadratureStyle,
) = _SpectralElementGrid1D(
topology,
quadrature_style,
Val(Topologies.nlocalelems(topology)),
)

function _SpectralElementGrid1D(
topology::Topologies.IntervalTopology,
quadrature_style::Quadratures.QuadratureStyle,
::Val{Nh},
) where {Nh}
global_geometry = Geometry.CartesianGlobalGeometry()
CoordType = Topologies.coordinate_type(topology)
AIdx = Geometry.coordinate_axis(CoordType)
FT = eltype(CoordType)
nelements = Topologies.nlocalelems(topology)
Nh = nelements
Nq = Quadratures.degrees_of_freedom(quadrature_style)

LG = Geometry.LocalGeometry{AIdx, CoordType, FT, SMatrix{1, 1, FT, 1}}
local_geometry = DataLayouts.IFH{LG, Nq, Nh}(Array{FT})
quad_points, quad_weights =
Quadratures.quadrature_points(FT, quadrature_style)

for elem in 1:nelements
for elem in 1:Nh
local_geometry_slab = slab(local_geometry, elem)
for i in 1:Nq
ξ = quad_points[i]
Expand Down Expand Up @@ -182,12 +190,24 @@ function get_CoordType2D(topology)
end
end

function _SpectralElementGrid2D(
_SpectralElementGrid2D(
topology::Topologies.Topology2D,
quadrature_style::Quadratures.QuadratureStyle;
enable_bubble::Bool,
) = _SpectralElementGrid2D(
topology,
quadrature_style,
Val(Topologies.nlocalelems(topology));
enable_bubble,
)

function _SpectralElementGrid2D(
topology::Topologies.Topology2D,
quadrature_style::Quadratures.QuadratureStyle,
::Val{Nh};
enable_bubble::Bool,
) where {Nh}

# 1. compute localgeom for local elememts
# 2. ghost exchange of localgeom
# 3. do a round of dss on WJs
Expand All @@ -213,8 +233,6 @@ function _SpectralElementGrid2D(
end
CoordType2D = get_CoordType2D(topology)
AIdx = Geometry.coordinate_axis(CoordType2D)
nlelems = Topologies.nlocalelems(topology)
Nh = nlelems
ngelems = Topologies.nghostelems(topology)
Nq = Quadratures.degrees_of_freedom(quadrature_style)
high_order_quadrature_style = Quadratures.GLL{Nq * 2}()
Expand Down
26 changes: 13 additions & 13 deletions test/Spaces/opt_spaces.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#=
julia --project=.buildkite
ENV["CLIMACOMMS_DEVICE"] = "CUDA";
using Revise; include(joinpath("test", "Spaces", "opt_spaces.jl"))
=#
import ClimaCore
import ClimaCore: Spaces, Grids
import ClimaCore: Spaces, Grids, Topologies
using Test
include(
joinpath(pkgdir(ClimaCore), "test", "TestUtilities", "TestUtilities.jl"),
Expand Down Expand Up @@ -34,19 +35,19 @@ end
#! format: off
if ClimaComms.device(context) isa ClimaComms.CUDADevice
test_n_failures(86, TU.PointSpace, context)
test_n_failures(144, TU.SpectralElementSpace1D, context)
test_n_failures(141, TU.SpectralElementSpace1D, context)
test_n_failures(1141, TU.SpectralElementSpace2D, context)
test_n_failures(123, TU.ColumnCenterFiniteDifferenceSpace, context)
test_n_failures(123, TU.ColumnFaceFiniteDifferenceSpace, context)
test_n_failures(1131, TU.SphereSpectralElementSpace, context)
test_n_failures(1139, TU.CenterExtrudedFiniteDifferenceSpace, context)
test_n_failures(1139, TU.FaceExtrudedFiniteDifferenceSpace, context)
test_n_failures(3, TU.ColumnCenterFiniteDifferenceSpace, context)
test_n_failures(4, TU.ColumnFaceFiniteDifferenceSpace, context)
test_n_failures(1147, TU.SphereSpectralElementSpace, context)
test_n_failures(1146, TU.CenterExtrudedFiniteDifferenceSpace, context)
test_n_failures(1146, TU.FaceExtrudedFiniteDifferenceSpace, context)
else
test_n_failures(0, TU.PointSpace, context)
test_n_failures(137, TU.SpectralElementSpace1D, context)
test_n_failures(310, TU.SpectralElementSpace2D, context)
test_n_failures(118, TU.ColumnCenterFiniteDifferenceSpace, context)
test_n_failures(118, TU.ColumnFaceFiniteDifferenceSpace, context)
test_n_failures(4, TU.ColumnCenterFiniteDifferenceSpace, context)
test_n_failures(5, TU.ColumnFaceFiniteDifferenceSpace, context)
test_n_failures(316, TU.SphereSpectralElementSpace, context)
test_n_failures(321, TU.CenterExtrudedFiniteDifferenceSpace, context)
test_n_failures(321, TU.FaceExtrudedFiniteDifferenceSpace, context)
Expand All @@ -56,11 +57,10 @@ end
# separately:

space = TU.CenterExtrudedFiniteDifferenceSpace(Float32; context=ClimaComms.context())
# @test_opt Grids._SpectralElementGrid2D(Spaces.topology(space), Spaces.quadrature_style(space); enable_bubble=false)

result = JET.@report_opt Grids._SpectralElementGrid2D(Spaces.topology(space), Spaces.quadrature_style(space); enable_bubble=false)
Nh = Val(Topologies.nlocalelems(Spaces.topology(space)))
result = JET.@report_opt Grids._SpectralElementGrid2D(Spaces.topology(space), Spaces.quadrature_style(space), Val(Nh); enable_bubble=false)
n_found = length(JET.get_reports(result.analyzer, result.result))
n_allowed = 187
n_allowed = 0
@test n_found n_allowed
if n_found < n_allowed
@info "Inference may have improved for _SpectralElementGrid2D: (n_found, n_allowed) = ($n_found, $n_allowed)"
Expand Down

0 comments on commit 5cd7a71

Please sign in to comment.