@@ -8,40 +8,6 @@ isdefined(Base, :get_extension) ?
8
8
(using ChainRulesCore; import ChainRulesCore: rrule, rrule_via_ad, RuleConfig, ProjectTo) :
9
9
(using .. ChainRulesCore; import ... ChainRulesCore: rrule, rrule_via_ad, RuleConfig, ProjectTo)
10
10
11
-
12
- # # StaticArrays
13
- # It's likely that StaticArrays will have its own ChainRulesCore extension someday, so we
14
- # need to check if there is already a ProjectTo defined for SArray. If so, we'll use that.
15
- # If not, we'll define one here.
16
- staticarrays_info = Pkg. dependencies ()[Base. UUID (" 90137ffa-7385-5640-81b9-e52037218182" )]
17
- if staticarrays_info. version < v " 1.8.1"
18
- # These are ripped from https://github.com/JuliaArrays/StaticArrays.jl/pull/1068
19
- function (project:: ProjectTo{<:Tangent{<:Tuple}} )(dx:: SArray )
20
- dy = reshape (dx, axes (project. elements)) # allows for dx::OffsetArray
21
- dz = ntuple (i -> project. elements[i](dy[i]), length (project. elements))
22
- return ChainRulesCore. project_type (project)(dz... )
23
- end
24
- function ProjectTo (x:: SArray{S,T} ) where {S, T}
25
- return ProjectTo {SArray} (;
26
- element= ChainRulesCore. _eltype_projectto (T),
27
- axes= axes (x), size= StaticArrays. Size (x)
28
- )
29
- end
30
- @inline _sarray_from_array (:: Size{T} , dx:: AbstractArray ) where {T} = SArray {Tuple{T...}} (dx)
31
- (project:: ProjectTo{SArray} )(dx:: AbstractArray ) = _sarray_from_array (project. size, dx)
32
- function rrule (:: Type{T} , x:: Tuple ) where {T <: SArray }
33
- project_x = ProjectTo (x)
34
- ∇Array (∂y) = (NoTangent (), project_x (∂y))
35
- return T (x), ∇Array
36
- end
37
- function rrule (:: Type{T} , xs:: Number... ) where {T <: SVector }
38
- project_x = ProjectTo (xs)
39
- ∇Array (∂y) = (NoTangent (), project_x (∂y)... )
40
- return T (xs... ), ∇Array
41
- end
42
- end
43
-
44
-
45
11
function rrule (:: Type{QT} , arg:: AbstractVector ) where {QT<: AbstractQuaternion }
46
12
AbstractQuaternion_pullback (Δquat) = (NoTangent (), components (unthunk (Δquat)))
47
13
return QT (arg), AbstractQuaternion_pullback
0 commit comments