Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jun 20, 2024
1 parent 6fa5586 commit 59fe268
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 17 deletions.
6 changes: 4 additions & 2 deletions src/Geometry/axistensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ struct AxisTensor{
N,
A <: NTuple{N, AbstractAxis},
S <: Union{
SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}},
# SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}},
SimpleSymmetric{N, T},
StaticArray{<:Tuple, T, N},
},
} <: AbstractArray{T, N}
Expand All @@ -151,7 +152,8 @@ AxisTensor(
) where {
A <: Tuple{Vararg{AbstractAxis}},
S <: Union{
SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}},
# SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}},
SimpleSymmetric{N, T},
StaticArray{<:Tuple, T, N},
},
} where {T, N} = AxisTensor{T, N, A, S}(axes, components)
Expand Down
13 changes: 6 additions & 7 deletions src/Geometry/localgeometry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ end
The necessary local metric information defined at each node.
"""
struct LocalGeometry{I, C <: AbstractPoint, FT, S, L}
struct LocalGeometry{I, C <: AbstractPoint, FT, S, N, L}
"Coordinates of the current point"
coordinates::C
"Jacobian determinant of the transformation `ξ` to `x`"
Expand All @@ -35,15 +35,13 @@ struct LocalGeometry{I, C <: AbstractPoint, FT, S, L}
gⁱʲ::Axis2Tensor{
FT,
Tuple{ContravariantAxis{I}, ContravariantAxis{I}},
# SimpleSymmetric{FT, S},
SimpleSymmetric{2, FT, L},
SimpleSymmetric{N, FT, L},
}
"Covariant metric tensor (gᵢⱼ), transforms contravariant to covariant vector components"
gᵢⱼ::Axis2Tensor{
FT,
Tuple{CovariantAxis{I}, CovariantAxis{I}},
# SimpleSymmetric{FT, S},
SimpleSymmetric{2, FT, L},
SimpleSymmetric{N, FT, L},
}
@inline function LocalGeometry(
coordinates,
Expand All @@ -61,7 +59,8 @@ struct LocalGeometry{I, C <: AbstractPoint, FT, S, L}
gⁱʲ = SimpleSymmetric(gⁱʲ₀)
gᵢⱼ = SimpleSymmetric(gᵢⱼ₀)
L = triangular_nonzeros(S)
return new{I, C, FT, S, L}(coordinates, J, WJ, Jinv, ∂x∂ξ, ∂ξ∂x, gⁱʲ, gᵢⱼ)
N = size(components(gⁱʲ₀), 1)
return new{I, C, FT, S, N, L}(coordinates, J, WJ, Jinv, ∂x∂ξ, ∂ξ∂x, gⁱʲ, gᵢⱼ)
end
end

Expand All @@ -77,7 +76,7 @@ struct SurfaceGeometry{FT, N}
normal::N
end

undertype(::Type{LocalGeometry{I, C, FT, S}}) where {I, C, FT, S} = FT
undertype(::Type{<:LocalGeometry{I, C, FT}}) where {I, C, FT} = FT
undertype(::Type{SurfaceGeometry{FT, N}}) where {FT, N} = FT

"""
Expand Down
7 changes: 4 additions & 3 deletions src/Geometry/simple_symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,14 @@ StaticArrays.check_parameters(

triangular_nonzeros(::SMatrix{N}) where {N} = Int(N * (N + 1) / 2)
triangular_nonzeros(::Type{<:SMatrix{N}}) where {N} = Int(N * (N + 1) / 2)
tail_params(::Type{S}) where {N,T, S<:SMatrix{N,N,T}} = (T, S, N, triangular_nonzeros(S))

function SimpleSymmetric(A::SMatrix)
@assert size(A, 1) == size(A, 2)
M = size(A, 1)
N = ndims(A)
N = size(A, 1)
nd = ndims(A)
T = eltype(A)
ci = ntuple(i -> 1:size(A, i), N)
ci = ntuple(i -> 1:size(A, i), nd)
upper_inds = filter(I -> I.I[2] I.I[1], CartesianIndices(ci))
L = triangular_nonzeros(A)
upper_triang = SVector{L}(map(I -> A[I], upper_inds))
Expand Down
3 changes: 2 additions & 1 deletion src/Grids/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ function fd_geometry_data(
) where {FT, periodic}
CT = Geometry.ZPoint{FT}
AIdx = (3,)
LG = Geometry.LocalGeometry{AIdx, CT, FT, SMatrix{1, 1, FT, 1}}
S = SMatrix{1, 1, FT, 1}
LG = Geometry.LocalGeometry{AIdx, CT, Geometry.tail_params(S)...}
(Ni, Nj, Nk, Nv, Nh) = size(face_coordinates)
Nv_face = Nv - periodic
Nv_cent = Nv - 1
Expand Down
6 changes: 4 additions & 2 deletions src/Grids/spectralelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ function _SpectralElementGrid1D(
nelements = Topologies.nlocalelems(topology)
Nq = Quadratures.degrees_of_freedom(quadrature_style)

LG = Geometry.LocalGeometry{AIdx, CoordType, FT, SMatrix{1, 1, FT, 1}}
S = SMatrix{1, 1, FT, 1}
LG = Geometry.LocalGeometry{AIdx, CoordType, Geometry.tail_params(S)...}
local_geometry = DataLayouts.IFH{LG, Nq}(Array{FT}, nelements)
quad_points, quad_weights =
Quadratures.quadrature_points(FT, quadrature_style)
Expand Down Expand Up @@ -218,7 +219,8 @@ function _SpectralElementGrid2D(
high_order_quadrature_style = Quadratures.GLL{Nq * 2}()
high_order_Nq = Quadratures.degrees_of_freedom(high_order_quadrature_style)

LG = Geometry.LocalGeometry{AIdx, CoordType2D, FT, SMatrix{2, 2, FT, 4}}
S = SMatrix{2, 2, FT, 4}
LG = Geometry.LocalGeometry{AIdx, CoordType2D, Geometry.tail_params(S)...}

local_geometry = DataLayouts.IJFH{LG, Nq}(Array{FT}, nlelems)

Expand Down
3 changes: 2 additions & 1 deletion test/DataLayouts/opt_similar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ function test_similar!(data)
FT = eltype(parent(data))
CT = Geometry.ZPoint{FT}
AIdx = (3,)
LG = Geometry.LocalGeometry{AIdx, CT, FT, SMatrix{1, 1, FT, 1}}
S = SMatrix{1, 1, FT, 1}
LG = Geometry.LocalGeometry{AIdx, CT, Geometry.tail_params(S)...}
(_, _, _, Nv, _) = size(data)
similar(data, LG, Val(Nv))
@test_opt similar(data, LG, Val(Nv))
Expand Down
11 changes: 10 additions & 1 deletion test/Geometry/axistensor_conversion_benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ function benchmark_func(args, key, f, flops, ::Type{FT}; print_method_info) wher
print("Time (opt, ref): ($(opt.t_pretty), $(ref.t_pretty)). Key: $key_str\n")
# end
end
correctness = compare(components(opt.result), components(ref.result)) # test correctness
# @show correctness
@show components(opt.result)
@show components(ref.result)
@show correctness
bm = (;
opt,
ref,
Expand All @@ -121,7 +126,7 @@ function benchmark_func(args, key, f, flops, ::Type{FT}; print_method_info) wher
flops, # current flops
computed_flops,
reduced_flops,
correctness = compare(opt.result, ref.result), # test correctness
correctness, # test correctness
perf_pass = (opt.time - ref.time)/ref.time*100 < -100, # test performance
)
return bm
Expand Down Expand Up @@ -149,6 +154,7 @@ components(x::T) where {T <: Real} = x
components(x) = Geometry.components(x)
compare(x::T, y::T) where {T<: Real} = x y || (x < eps(T)/100 && y < eps(T)/100)
compare(x::T, y::T) where {T <: SMatrix} = all(compare.(x, y))
compare(x::T, y::T) where {T <: Geometry.SimpleSymmetric} = all(compare.(x.lowertriangle, y.lowertriangle))
compare(x::T, y::T) where {T <: SVector} = all(compare.(x, y))
compare(x::T, y::T) where {T <: AxisTensor} = compare(components(x), components(y))

Expand All @@ -168,6 +174,9 @@ function test_optimized_functions(::Type{FT}; print_method_info=false) where {FT
end

for key in keys(benchmarks)
if !(benchmarks[key].correctness)
@show key
end
@test benchmarks[key].correctness # test correctness
@test benchmarks[key].Δflops 0 # Don't regress
# @test_broken benchmarks[key].Δflops < 0 # Error on improvements. TODO: fix, this is somehow flakey
Expand Down
8 changes: 8 additions & 0 deletions test/Geometry/unit_simple_symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Revise; include(joinpath("test", "Geometry", "unit_simple_symmetric.jl"))
using Test
using StaticArrays
using ClimaCore.Geometry: SimpleSymmetric
import ClimaCore.Geometry
using JET
simple_symmetric(A::Matrix) = SimpleSymmetric(SMatrix{size(A)..., eltype(A)}(A))

Expand All @@ -27,4 +28,11 @@ simple_symmetric(A::Matrix) = SimpleSymmetric(SMatrix{size(A)..., eltype(A)}(A))
A = @SMatrix [1 2; 2 4]
@test SimpleSymmetric(A) / 2 === SimpleSymmetric(A / 2)
@test_opt SimpleSymmetric(A)
@test Geometry.tail_params(typeof(@SMatrix Float32[1 2; 2 4])) == (Float32, SMatrix{2, 2, Float32, 4}, 2, 3)
end

@testset "sizs" begin
for N in (1,2,3,5,8,10)
simple_symmetric(rand(N,N)) # pass in non-symmetric matrix
end
end

0 comments on commit 59fe268

Please sign in to comment.