Skip to content

Commit

Permalink
gpu fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jun 20, 2024
1 parent 59fe268 commit fa21f52
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 12 deletions.
2 changes: 0 additions & 2 deletions src/Geometry/axistensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ struct AxisTensor{
N,
A <: NTuple{N, AbstractAxis},
S <: Union{
# SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}},
SimpleSymmetric{N, T},
StaticArray{<:Tuple, T, N},
},
Expand All @@ -152,7 +151,6 @@ AxisTensor(
) where {
A <: Tuple{Vararg{AbstractAxis}},
S <: Union{
# SimpleSymmetric{T, <:StaticArray{<:Tuple, T, N}},
SimpleSymmetric{N, T},
StaticArray{<:Tuple, T, N},
},
Expand Down
40 changes: 30 additions & 10 deletions src/Geometry/simple_symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,36 @@ triangular_nonzeros(::SMatrix{N}) where {N} = Int(N * (N + 1) / 2)
triangular_nonzeros(::Type{<:SMatrix{N}}) where {N} = Int(N * (N + 1) / 2)
tail_params(::Type{S}) where {N,T, S<:SMatrix{N,N,T}} = (T, S, N, triangular_nonzeros(S))

function SimpleSymmetric(A::SMatrix)
@assert size(A, 1) == size(A, 2)
N = size(A, 1)
nd = ndims(A)
T = eltype(A)
ci = ntuple(i -> 1:size(A, i), nd)
upper_inds = filter(I -> I.I[2] I.I[1], CartesianIndices(ci))
L = triangular_nonzeros(A)
upper_triang = SVector{L}(map(I -> A[I], upper_inds))
SimpleSymmetric{N, T, L}(upper_triang)
# function SimpleSymmetric(A::SMatrix)
# @assert size(A, 1) == size(A, 2)
# N = size(A, 1)
# nd = ndims(A)
# T = eltype(A)
# ci = ntuple(i -> 1:size(A, i), nd)
# upper_inds = filter(I -> I.I[2] ≥ I.I[1], CartesianIndices(ci))
# L = triangular_nonzeros(A)
# upper_triang = SVector{L}(map(I -> A[I], upper_inds))
# SimpleSymmetric{N, T, L}(upper_triang)
# end

@generated function SimpleSymmetric(A::S) where {S <: SMatrix}
N = size(S, 1)
L = triangular_nonzeros(S)
_check_simple_symmetric_parameters(Val(N), Val(L))
expr = Vector{Expr}(undef, L)
T = eltype(S)
i = 0
for col in 1:N, row in 1:N
if col row
expr[i += 1] = :(A[$row, $col])
end
end
quote
Base.@_inline_meta
@inbounds return SimpleSymmetric{$N, $T, $L}(
SVector{$L, $T}(tuple($(expr...))),
)
end
end

@inline function _check_simple_symmetric_parameters(
Expand Down

0 comments on commit fa21f52

Please sign in to comment.