From fa21f528e3e706d6ddc64ba4a251d2c930fe8072 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Thu, 20 Jun 2024 17:02:48 -0400 Subject: [PATCH] gpu fixes --- src/Geometry/axistensors.jl | 2 -- src/Geometry/simple_symmetric.jl | 40 ++++++++++++++++++++++++-------- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/src/Geometry/axistensors.jl b/src/Geometry/axistensors.jl index 50ea64e5bd..bf225e9b2b 100644 --- a/src/Geometry/axistensors.jl +++ b/src/Geometry/axistensors.jl @@ -137,7 +137,6 @@ struct AxisTensor{ N, A <: NTuple{N, AbstractAxis}, S <: Union{ - # SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}}, SimpleSymmetric{N, T}, StaticArray{<:Tuple, T, N}, }, @@ -152,7 +151,6 @@ AxisTensor( ) where { A <: Tuple{Vararg{AbstractAxis}}, S <: Union{ - # SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}}, SimpleSymmetric{N, T}, StaticArray{<:Tuple, T, N}, }, diff --git a/src/Geometry/simple_symmetric.jl b/src/Geometry/simple_symmetric.jl index 202161d126..755af73f46 100644 --- a/src/Geometry/simple_symmetric.jl +++ b/src/Geometry/simple_symmetric.jl @@ -55,16 +55,36 @@ triangular_nonzeros(::SMatrix{N}) where {N} = Int(N * (N + 1) / 2) triangular_nonzeros(::Type{<:SMatrix{N}}) where {N} = Int(N * (N + 1) / 2) tail_params(::Type{S}) where {N,T, S<:SMatrix{N,N,T}} = (T, S, N, triangular_nonzeros(S)) -function SimpleSymmetric(A::SMatrix) - @assert size(A, 1) == size(A, 2) - N = size(A, 1) - nd = ndims(A) - T = eltype(A) - ci = ntuple(i -> 1:size(A, i), nd) - upper_inds = filter(I -> I.I[2] ≥ I.I[1], CartesianIndices(ci)) - L = triangular_nonzeros(A) - upper_triang = SVector{L}(map(I -> A[I], upper_inds)) - SimpleSymmetric{N, T, L}(upper_triang) +# function SimpleSymmetric(A::SMatrix) +# @assert size(A, 1) == size(A, 2) +# N = size(A, 1) +# nd = ndims(A) +# T = eltype(A) +# ci = ntuple(i -> 1:size(A, i), nd) +# upper_inds = filter(I -> I.I[2] ≥ I.I[1], CartesianIndices(ci)) +# L = triangular_nonzeros(A) +# upper_triang = SVector{L}(map(I -> A[I], upper_inds)) +# SimpleSymmetric{N, T, L}(upper_triang) +# end + +@generated function SimpleSymmetric(A::S) where {S <: SMatrix} + N = size(S, 1) + L = triangular_nonzeros(S) + _check_simple_symmetric_parameters(Val(N), Val(L)) + expr = Vector{Expr}(undef, L) + T = eltype(S) + i = 0 + for col in 1:N, row in 1:N + if col ≥ row + expr[i += 1] = :(A[$row, $col]) + end + end + quote + Base.@_inline_meta + @inbounds return SimpleSymmetric{$N, $T, $L}( + SVector{$L, $T}(tuple($(expr...))), + ) + end end @inline function _check_simple_symmetric_parameters(