|
1 | 1 | module QuaternionicChainRulesCoreExt
|
2 | 2 |
|
3 |
| -using Pkg |
4 | 3 | using Quaternionic
|
5 | 4 | import Quaternionic: _sincu, _cossu
|
6 | 5 | using StaticArrays
|
7 | 6 | isdefined(Base, :get_extension) ?
|
8 | 7 | (using ChainRulesCore; import ChainRulesCore: rrule, rrule_via_ad, RuleConfig, ProjectTo) :
|
9 | 8 | (using ..ChainRulesCore; import ...ChainRulesCore: rrule, rrule_via_ad, RuleConfig, ProjectTo)
|
10 | 9 |
|
11 |
| - |
12 |
| -## StaticArrays |
13 |
| -# It's likely that StaticArrays will have its own ChainRulesCore extension someday, so we |
14 |
| -# need to check if there is already a ProjectTo defined for SArray. If so, we'll use that. |
15 |
| -# If not, we'll define one here. |
16 |
| -staticarrays_info = Pkg.dependencies()[Base.UUID("90137ffa-7385-5640-81b9-e52037218182")] |
17 |
| -if staticarrays_info.version < v"1.8.1" |
18 |
| - # These are ripped from https://github.com/JuliaArrays/StaticArrays.jl/pull/1068 |
19 |
| - function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::SArray) |
20 |
| - dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray |
21 |
| - dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements)) |
22 |
| - return ChainRulesCore.project_type(project)(dz...) |
23 |
| - end |
24 |
| - function ProjectTo(x::SArray{S,T}) where {S, T} |
25 |
| - return ProjectTo{SArray}(; |
26 |
| - element=ChainRulesCore._eltype_projectto(T), |
27 |
| - axes=axes(x), size=StaticArrays.Size(x) |
28 |
| - ) |
29 |
| - end |
30 |
| - @inline _sarray_from_array(::Size{T}, dx::AbstractArray) where {T} = SArray{Tuple{T...}}(dx) |
31 |
| - (project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.size, dx) |
32 |
| - function rrule(::Type{T}, x::Tuple) where {T <: SArray} |
33 |
| - project_x = ProjectTo(x) |
34 |
| - ∇Array(∂y) = (NoTangent(), project_x(∂y)) |
35 |
| - return T(x), ∇Array |
36 |
| - end |
37 |
| - function rrule(::Type{T}, xs::Number...) where {T <: SVector} |
38 |
| - project_x = ProjectTo(xs) |
39 |
| - ∇Array(∂y) = (NoTangent(), project_x(∂y)...) |
40 |
| - return T(xs...), ∇Array |
41 |
| - end |
42 |
| -end |
43 |
| - |
44 |
| - |
45 | 10 | function rrule(::Type{QT}, arg::AbstractVector) where {QT<:AbstractQuaternion}
|
46 | 11 | AbstractQuaternion_pullback(Δquat) = (NoTangent(), components(unthunk(Δquat)))
|
47 | 12 | return QT(arg), AbstractQuaternion_pullback
|
|
0 commit comments