Skip to content

Commit

Permalink
Add FastDifferentiation support (minimal testing)
Browse files Browse the repository at this point in the history
  • Loading branch information
moble committed Jun 24, 2024
1 parent 9ff6c4c commit 93c2e5a
Show file tree
Hide file tree
Showing 9 changed files with 226 additions and 37 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
ChainRulesCore = "1"
FastDifferentiation = "0.3.14"
ForwardDiff = "0.10"
GenericLinearAlgebra = "0.3.11"
LaTeXStrings = "1"
Expand All @@ -28,11 +29,13 @@ julia = "1.6"

[extensions]
QuaternionicChainRulesCoreExt = "ChainRulesCore"
QuaternionicFastDifferentiationExt = "FastDifferentiation"
QuaternionicForwardDiffExt = "ForwardDiff"
QuaternionicSymbolicsExt = "Symbolics"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
181 changes: 181 additions & 0 deletions ext/QuaternionicFastDifferentiationExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
module QuaternionicFastDifferentiationExt

using StaticArrays: SVector
using Latexify: latexify
import Quaternionic: AbstractQuaternion, Quaternion, Rotor, QuatVec,
quaternion, rotor, quatvec,
QuatVecF64, RotorF64, QuaternionF64,
wrapper, components, _pm_ascii, _pm_latex
using PrecompileTools
isdefined(Base, :get_extension) ? (using FastDifferentiation) : (using ..FastDifferentiation)


### TYPE PIRACY!!!
# The following Base functions should be added to FastDifferentiation itself; meanwhile
# we'll define them here.

# This is a workaround that should be fixed in FastDifferentiation.jl; Elaine will probably
# make a PR to fix this. The problem is that FD only defines
# `Base.promote_rule(::Type{<:Real}, ::Type{Node})`. But julia/src/bool.jl defines
# `promote_rule(::Type{Bool}, ::Type{T}) where T<:Number` and `Node <: Number`, so there is
# an ambiguity.
Base.promote_rule(::Type{Bool}, ::Type{FastDifferentiation.Node}) = FastDifferentiation.Node

# These are essentially copied from Symbolics.jl:
# https://github.com/JuliaSymbolics/Symbolics.jl/blob/e4c328103ece494eaaab2a265524a64bfbe43dbd/src/num.jl#L31-L34
Base.eps(::Type{FastDifferentiation.Node}) = FastDifferentiation.Node(0)
Base.typemin(::Type{FastDifferentiation.Node}) = FastDifferentiation.Node(-Inf)
Base.typemax(::Type{FastDifferentiation.Node}) = FastDifferentiation.Node(Inf)
Base.float(x::FastDifferentiation.Node) = x

# This one is needed because julia/base/float.jl only defines `isinf` for `Real`, but `Node
# <: Number`. (See https://github.com/brianguenter/FastDifferentiation.jl/issues/73)
Base.isinf(x::FastDifferentiation.Node) = !isnan(x) & !isfinite(x)


normalize(v::AbstractVector{FastDifferentiation.Node}) = v ./ sum(x->x^2, v)


### Functions that used to appear in quaternion.jl
quaternion(w::FastDifferentiation.Node) = quaternion(SVector{4}(w, false, false, false))
rotor(w::FastDifferentiation.Node) = rotor(SVector{4}(one(w), false, false, false))
quatvec(w::FastDifferentiation.Node) = quatvec(SVector{4,typeof(w)}(false, false, false, false))
for QT1 (AbstractQuaternion, Quaternion, QuatVec, Rotor)
@eval begin
wrapper(::Type{<:$QT1}, ::Val{OP}, ::Type{<:FastDifferentiation.Node}) where {OP} = quaternion
wrapper(::Type{<:FastDifferentiation.Node}, ::Val{OP}, ::Type{<:$QT1}) where {OP} = quaternion
end
end
let NT = FastDifferentiation.Node
for QT (QuatVec,)
for OP (Val{*}, Val{/})
@eval begin
wrapper(::Type{<:$QT}, ::$OP, ::Type{<:$NT}) = quatvec
wrapper(::Type{<:$NT}, ::$OP, ::Type{<:$QT}) = quatvec
end
end
end
for QT (Rotor,)
for OP (Val{+}, Val{-}, Val{*}, Val{/})
@eval begin
wrapper(::Type{<:$QT}, ::$OP, ::Type{<:$NT}) = quaternion
wrapper(::Type{<:$NT}, ::$OP, ::Type{<:$QT}) = quaternion
end
end
end
end
let T = FastDifferentiation.Node
for OP (Val{+}, Val{-}, Val{*}, Val{/})
@eval wrapper(::Type{<:Quaternion}, ::$OP, ::Type{<:$T}) = quaternion
if T !== Quaternion
@eval wrapper(::Type{<:$T}, ::$OP, ::Type{<:Quaternion}) = quaternion
end
end
end
Base.promote_rule(::Type{Q}, ::Type{S}) where {Q<:AbstractQuaternion,S<:FastDifferentiation.Node} =
wrapper(Q){promote_type(eltype(Q), S)}


# function _pm_ascii(x::FastDifferentiation.Node)
# # Utility function to print a component of a quaternion
# s = "$x"
# if s[1] ∉ "+-"
# s = "+" * s
# end
# if occursin(r"[+^/-]", s[2:end])
# if s[1] == '+'
# s = " + " * "(" * s[2:end] * ")"
# else
# s = " + " * "(" * s * ")"
# end
# else
# s = " " * s[1] * " " * s[2:end]
# end
# s
# end
# function _pm_latex(x::Num)
# # Utility function to print a component of a quaternion in LaTeX
# s = latexify(x, env=:raw, bracket=true)
# if s[1] ∉ "+-"
# s = "+" * s
# end
# if occursin(r"[+^/-]", s[2:end])
# if s[1] == '+'
# s = " + " * "\\left(" * s[2:end] * "\\right)"
# else
# s = " + " * "\\left(" * s * "\\right)"
# end
# else
# s = " " * s[1] * " " * s[2:end]
# end
# s
# end


# # Broadcast-like operations from FastDifferentiation
# # (d::FastDifferentiation.Operator)(q::QT) where {QT<:AbstractQuaternion} = QT(d(q[1]), d(q[2]), d(q[3]), d(q[4]))
# # (d::FastDifferentiation.Operator)(q::QuatVec) = quatvec(d(q[2]), d(q[3]), d(q[4]))
# (d::FastDifferentiation.Differential)(q::Quaternion) = quaternion(d(q[1]), d(q[2]), d(q[3]), d(q[4]))
# (d::FastDifferentiation.Differential)(q::Rotor) = quaternion(d(q[1]), d(q[2]), d(q[3]), d(q[4]))
# (d::FastDifferentiation.Differential)(q::QuatVec) = quatvec(d(q[2]), d(q[3]), d(q[4]))


### Functions that used to appear in algebra.jl
for TA (AbstractQuaternion, Rotor, QuatVec)
let TB = FastDifferentiation.Node
@eval begin
Base.:+(q::QT, p::$TB) where {QT<:$TA} = wrapper($TA, Val(+), $TB)(q[1]+p, q[2], q[3], q[4])
Base.:-(q::QT, p::$TB) where {QT<:$TA} = wrapper($TA, Val(-), $TB)(q[1]-p, q[2], q[3], q[4])
Base.:+(p::$TB, q::QT) where {QT<:$TA} = wrapper($TB, Val(+), $TA)(p+q[1], q[2], q[3], q[4])
Base.:-(p::$TB, q::QT) where {QT<:$TA} = wrapper($TB, Val(-), $TA)(p-q[1], -q[2], -q[3], -q[4])
end
end
end
let S = FastDifferentiation.Node
@eval begin
Base.:*(p::Q, s::$S) where {Q<:AbstractQuaternion} = wrapper(Q, Val(*), $S)(s*components(p))
Base.:*(s::$S, p::Q) where {Q<:AbstractQuaternion} = wrapper($S, Val(*), Q)(s*components(p))
Base.:/(p::Q, s::$S) where {Q<:AbstractQuaternion} = wrapper(Q, Val(/), $S)(components(p)/s)
function Base.:/(s::$S, p::Q) where {Q<:AbstractQuaternion}
f = s / abs2(p)
wrapper($S, Val(/), Q)(p[1] * f, -p[2] * f, -p[3] * f, -p[4] * f)
end
end
end


# Pre-compilation

@setup_workload begin
# Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the
# size of the precompile file and potentially make loading faster.
FastDifferentiation.@variables w x y z a b c d e
s = randn(Float64)
v = randn(QuatVecF64)
r = randn(RotorF64)
q = randn(QuaternionF64)
𝓈 = w
𝓋 = quatvec(x, y, z)
𝓇 = rotor(a, b, c, d)
𝓆 = quaternion(w, x, y, z)

@compile_workload begin
# all calls in this block will be precompiled, regardless of whether they belong to
# this package or not (on Julia 1.8 and higher)
r(v)
𝓇(𝓋)
for a [s, v, r, q, 𝓈, 𝓋, 𝓇, 𝓆]
conj(a)
for b [s, v, r, q, 𝓈, 𝓋, 𝓇, 𝓆]
a * b
a / b
a + b
a - b
end
end

end
end


end # module
4 changes: 3 additions & 1 deletion ext/QuaternionicSymbolicsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ using PrecompileTools
isdefined(Base, :get_extension) ? (using Symbolics) : (using ..Symbolics)


normalize(v::AbstractVector{Symbolics.Num}) = v ./ sum(x->x^2, v)

### Functions that used to appear in quaternion.jl
quaternion(w::Symbolics.Num) = quaternion(SVector{4}(w, false, false, false))
rotor(w::Symbolics.Num) = rotor(SVector{4}(one(w), false, false, false))
Expand Down Expand Up @@ -257,7 +259,7 @@ end

@compile_workload begin
# all calls in this block will be precompiled, regardless of whether they belong to
# your package or not (on Julia 1.8 and higher)
# this package or not (on Julia 1.8 and higher)
r(v)
Symbolics.simplify.(𝓇(𝓋))
for a [s, v, r, q, 𝓈, 𝓋, 𝓇, 𝓆]
Expand Down
1 change: 1 addition & 0 deletions src/Quaternionic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ end

function __init__()
@require ChainRulesCore="d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include("../ext/QuaternionicChainRulesCoreExt.jl")
@require FastDifferentiation="eb9bf01b-bf85-4b60-bf87-ee5de06c00be" include("../ext/QuaternionicFastDifferentiationExt.jl")
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("../ext/QuaternionicForwardDiffExt.jl")
@require Symbolics="0c5d862f-8b57-4792-8d23-62f2024744c7" include("../ext/QuaternionicSymbolicsExt.jl")
end
Expand Down
2 changes: 1 addition & 1 deletion src/examples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ julia> R, ω⃗, Ṙ = precessing_nutating_example();
julia> R(12.34)
rotor(0.9944579779058746 + 0.09804177421238346𝐢 - 0.0008485045352531196𝐣 + 0.03795287510453948𝐤)
julia> ω⃗(345.67)
+ 0.0004634300734286701𝐢 - 0.0007032818419003175𝐣 + 0.006214814810035088𝐤
+ 0.00046343007342866996𝐢 - 0.0007032818419003173𝐣 + 0.006214814810035087𝐤
julia> ϵ = 1e-6; (R(ϵ) - R(-ϵ)) / 2ϵ # Approximate derivative at t=0
-3.8491432263754177e-7 + (3.9080960689830135e-6)𝐢 - (6.861695854245622e-5)𝐣 + 0.003076329202503836𝐤
julia> Ṙ(0)
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
4 changes: 2 additions & 2 deletions test/algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,13 @@ end
chars = Iterators.Stateful(Iterators.cycle("abcdefghijkl"))
function next_scalar!(chars)
x = Symbol(popfirst!(chars))
xvar = @variables $x
xvar = Symbolics.@variables $x
xvar[1]
end
function next_quaternion!(chars)
x = Symbol(popfirst!(chars))
# May have to work around <https://github.com/JuliaSymbolics/Symbolics.jl/issues/379>:
xvar = @variables $x[1:4]
xvar = Symbolics.@variables $x[1:4]
quaternion(xvar[1]...)
end

Expand Down
56 changes: 28 additions & 28 deletions test/base.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
@testset verbose=true "Base" begin
@testset "Numbers $T" for T in Types
# Note that, because `Num` from Symbolics is a weird type, we have to
# be a little more explicit below than we normally would be. Also,
# because of signed zeros in the float types, we have to take the
# absolute value of the difference before comparing to zero.
# Note that, because `Symbolics.Num` is a weird type, we have to be a little more
# explicit below than we normally would be. Also, because of signed zeros in the
# float types, we have to take the absolute value of the difference before comparing
# to zero.

# Define basis elements
u = Quaternion{T}(1)
Expand Down Expand Up @@ -42,7 +42,7 @@
if !(T<:Integer)
@test rotor(T(1), 2, 3, 4) == Rotor{T}(SVector{4, T}(1, 2, 3, 4)/√T(30))
@test Rotor{T}(1, 0, 0, 0) == Rotor{T}(1)
if !(T<:Num)
if !(T<:Symbolics.Num)
@test rotor(T[1, 2, 3, 4]...) rotor(SVector{4, T}(1, 2, 3, 4)/√T(30)) rtol=0 atol=2eps(T)
@test rotor(T[0, 2, 3, 4]...) rotor(T(2), T(3), T(4)) rtol=0 atol=2eps(T)
@test rotor(T[1, 0, 0, 0]...) rotor(T(1)) rtol=0 atol=2eps(T)
Expand Down Expand Up @@ -81,24 +81,24 @@
end

for v [𝐢, 𝐣, 𝐤]
@test Num(1)*v != one(T)
@test !isequal(Num(1)*v, one(T))
@test Quaternion(Num(one(T))) == one(T)
@test Quaternion(Num(7one(T))) == 7one(T)
@test one(T) == Quaternion(Num(one(T)))
@test 7one(T) == Quaternion(Num(7one(T)))

@test QuatVec(Num[1,2,3,4]) != one(T)
@test QuatVec(Num[7,2,3,4]) != 7one(T)
@test one(T) != QuatVec(Num[1,2,3,4])
@test 7one(T) != QuatVec(Num[7,2,3,4])

@test QuatVec(Num[1,2,3,4]) == QuatVec{T}(0,2,3,4)
@test QuatVec{T}(0,2,3,4) == QuatVec(Num[1,2,3,4])
@test QuatVec(Num[1,2,3,4]) == Quaternion{T}(0,2,3,4)
@test Quaternion{T}(0,2,3,4) == QuatVec(Num[1,2,3,4])
@test QuatVec(Num[1,2,3,4]) != Quaternion{T}(1,2,3,4)
@test Quaternion{T}(1,2,3,4) != QuatVec(Num[1,2,3,4])
@test Symbolics.Num(1)*v != one(T)
@test !isequal(Symbolics.Num(1)*v, one(T))
@test Quaternion(Symbolics.Num(one(T))) == one(T)
@test Quaternion(Symbolics.Num(7one(T))) == 7one(T)
@test one(T) == Quaternion(Symbolics.Num(one(T)))
@test 7one(T) == Quaternion(Symbolics.Num(7one(T)))

@test QuatVec(Symbolics.Num[1,2,3,4]) != one(T)
@test QuatVec(Symbolics.Num[7,2,3,4]) != 7one(T)
@test one(T) != QuatVec(Symbolics.Num[1,2,3,4])
@test 7one(T) != QuatVec(Symbolics.Num[7,2,3,4])

@test QuatVec(Symbolics.Num[1,2,3,4]) == QuatVec{T}(0,2,3,4)
@test QuatVec{T}(0,2,3,4) == QuatVec(Symbolics.Num[1,2,3,4])
@test QuatVec(Symbolics.Num[1,2,3,4]) == Quaternion{T}(0,2,3,4)
@test Quaternion{T}(0,2,3,4) == QuatVec(Symbolics.Num[1,2,3,4])
@test QuatVec(Symbolics.Num[1,2,3,4]) != Quaternion{T}(1,2,3,4)
@test Quaternion{T}(1,2,3,4) != QuatVec(Symbolics.Num[1,2,3,4])
end

# Test indexing
Expand Down Expand Up @@ -145,7 +145,7 @@
# Check isapprox
@test u one(T)
@test one(T) u
if T !== Num # Num(1) ≉ Num(2) doesn't work
if T !== Symbolics.Num # Symbolics.Num(1) ≉ Symbolics.Num(2) doesn't work
@test i one(T)
@test one(T) i
@test j one(T)
Expand Down Expand Up @@ -184,7 +184,7 @@
@test !isreal(k)

# Check "isinteger"
if T != Num
if T != Symbolics.Num
@test isinteger(u)
@test !isinteger(1.2u)
end
Expand Down Expand Up @@ -243,7 +243,7 @@
end
end

if T != Num
if T != Symbolics.Num
# Check "in"
@test u 0:2
@test u 2:4
Expand Down Expand Up @@ -347,8 +347,8 @@
end

@testset "Differential" begin
@variables t Q(t)[1:4] R(t)[1:4] V(t)[1:3]
∂ₜ = Differential(t)
Symbolics.@variables t Q(t)[1:4] R(t)[1:4] V(t)[1:3]
∂ₜ = Symbolics.Differential(t)
Q = quaternion(Q...)
R = rotor(R...)
V = quatvec(V...)
Expand Down
Loading

0 comments on commit 93c2e5a

Please sign in to comment.