diff --git a/src/Grids/spectralelement.jl b/src/Grids/spectralelement.jl index 3913c52913..c880c99829 100644 --- a/src/Grids/spectralelement.jl +++ b/src/Grids/spectralelement.jl @@ -39,16 +39,24 @@ function SpectralElementGrid1D( end end -function _SpectralElementGrid1D( +_SpectralElementGrid1D( topology::Topologies.IntervalTopology, quadrature_style::Quadratures.QuadratureStyle, +) = _SpectralElementGrid1D( + topology, + quadrature_style, + Val(Topologies.nlocalelems(topology)), ) + +function _SpectralElementGrid1D( + topology::Topologies.IntervalTopology, + quadrature_style::Quadratures.QuadratureStyle, + ::Val{Nh}, +) where {Nh} global_geometry = Geometry.CartesianGlobalGeometry() CoordType = Topologies.coordinate_type(topology) AIdx = Geometry.coordinate_axis(CoordType) FT = eltype(CoordType) - nelements = Topologies.nlocalelems(topology) - Nh = nelements Nq = Quadratures.degrees_of_freedom(quadrature_style) LG = Geometry.LocalGeometry{AIdx, CoordType, FT, SMatrix{1, 1, FT, 1}} @@ -56,7 +64,7 @@ function _SpectralElementGrid1D( quad_points, quad_weights = Quadratures.quadrature_points(FT, quadrature_style) - for elem in 1:nelements + for elem in 1:Nh local_geometry_slab = slab(local_geometry, elem) for i in 1:Nq ξ = quad_points[i] @@ -182,12 +190,24 @@ function get_CoordType2D(topology) end end -function _SpectralElementGrid2D( +_SpectralElementGrid2D( topology::Topologies.Topology2D, quadrature_style::Quadratures.QuadratureStyle; enable_bubble::Bool, +) = _SpectralElementGrid2D( + topology, + quadrature_style, + Val(Topologies.nlocalelems(topology)); + enable_bubble, ) +function _SpectralElementGrid2D( + topology::Topologies.Topology2D, + quadrature_style::Quadratures.QuadratureStyle, + ::Val{Nh}; + enable_bubble::Bool, +) where {Nh} + # 1. compute localgeom for local elememts # 2. ghost exchange of localgeom # 3. do a round of dss on WJs @@ -213,8 +233,6 @@ function _SpectralElementGrid2D( end CoordType2D = get_CoordType2D(topology) AIdx = Geometry.coordinate_axis(CoordType2D) - nlelems = Topologies.nlocalelems(topology) - Nh = nlelems ngelems = Topologies.nghostelems(topology) Nq = Quadratures.degrees_of_freedom(quadrature_style) high_order_quadrature_style = Quadratures.GLL{Nq * 2}() diff --git a/test/Spaces/opt_spaces.jl b/test/Spaces/opt_spaces.jl index c2d66c8876..5e52b9004d 100644 --- a/test/Spaces/opt_spaces.jl +++ b/test/Spaces/opt_spaces.jl @@ -1,9 +1,10 @@ #= julia --project=.buildkite +ENV["CLIMACOMMS_DEVICE"] = "CUDA"; using Revise; include(joinpath("test", "Spaces", "opt_spaces.jl")) =# import ClimaCore -import ClimaCore: Spaces, Grids +import ClimaCore: Spaces, Grids, Topologies using Test include( joinpath(pkgdir(ClimaCore), "test", "TestUtilities", "TestUtilities.jl"), @@ -34,19 +35,19 @@ end #! format: off if ClimaComms.device(context) isa ClimaComms.CUDADevice test_n_failures(86, TU.PointSpace, context) - test_n_failures(144, TU.SpectralElementSpace1D, context) + test_n_failures(141, TU.SpectralElementSpace1D, context) test_n_failures(1141, TU.SpectralElementSpace2D, context) - test_n_failures(123, TU.ColumnCenterFiniteDifferenceSpace, context) - test_n_failures(123, TU.ColumnFaceFiniteDifferenceSpace, context) - test_n_failures(1131, TU.SphereSpectralElementSpace, context) - test_n_failures(1139, TU.CenterExtrudedFiniteDifferenceSpace, context) - test_n_failures(1139, TU.FaceExtrudedFiniteDifferenceSpace, context) + test_n_failures(3, TU.ColumnCenterFiniteDifferenceSpace, context) + test_n_failures(4, TU.ColumnFaceFiniteDifferenceSpace, context) + test_n_failures(1147, TU.SphereSpectralElementSpace, context) + test_n_failures(1146, TU.CenterExtrudedFiniteDifferenceSpace, context) + test_n_failures(1146, TU.FaceExtrudedFiniteDifferenceSpace, context) else test_n_failures(0, TU.PointSpace, context) test_n_failures(137, TU.SpectralElementSpace1D, context) test_n_failures(310, TU.SpectralElementSpace2D, context) - test_n_failures(118, TU.ColumnCenterFiniteDifferenceSpace, context) - test_n_failures(118, TU.ColumnFaceFiniteDifferenceSpace, context) + test_n_failures(4, TU.ColumnCenterFiniteDifferenceSpace, context) + test_n_failures(5, TU.ColumnFaceFiniteDifferenceSpace, context) test_n_failures(316, TU.SphereSpectralElementSpace, context) test_n_failures(321, TU.CenterExtrudedFiniteDifferenceSpace, context) test_n_failures(321, TU.FaceExtrudedFiniteDifferenceSpace, context) @@ -56,11 +57,10 @@ end # separately: space = TU.CenterExtrudedFiniteDifferenceSpace(Float32; context=ClimaComms.context()) - # @test_opt Grids._SpectralElementGrid2D(Spaces.topology(space), Spaces.quadrature_style(space); enable_bubble=false) - - result = JET.@report_opt Grids._SpectralElementGrid2D(Spaces.topology(space), Spaces.quadrature_style(space); enable_bubble=false) + Nh = Val(Topologies.nlocalelems(Spaces.topology(space))) + result = JET.@report_opt Grids._SpectralElementGrid2D(Spaces.topology(space), Spaces.quadrature_style(space), Val(Nh); enable_bubble=false) n_found = length(JET.get_reports(result.analyzer, result.result)) - n_allowed = 187 + n_allowed = 0 @test n_found ≤ n_allowed if n_found < n_allowed @info "Inference may have improved for _SpectralElementGrid2D: (n_found, n_allowed) = ($n_found, $n_allowed)"