From 93c2e5a8b34b22e2f0e68eb07f7106cf4aea20d4 Mon Sep 17 00:00:00 2001 From: Mike Boyle Date: Mon, 24 Jun 2024 12:23:00 -0400 Subject: [PATCH] Add FastDifferentiation support (minimal testing) --- Project.toml | 3 + ext/QuaternionicFastDifferentiationExt.jl | 181 ++++++++++++++++++++++ ext/QuaternionicSymbolicsExt.jl | 4 +- src/Quaternionic.jl | 1 + src/examples.jl | 2 +- test/Project.toml | 1 + test/algebra.jl | 4 +- test/base.jl | 56 +++---- test/runtests.jl | 11 +- 9 files changed, 226 insertions(+), 37 deletions(-) create mode 100644 ext/QuaternionicFastDifferentiationExt.jl diff --git a/Project.toml b/Project.toml index 84a54ab..8e056fe 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] ChainRulesCore = "1" +FastDifferentiation = "0.3.14" ForwardDiff = "0.10" GenericLinearAlgebra = "0.3.11" LaTeXStrings = "1" @@ -28,11 +29,13 @@ julia = "1.6" [extensions] QuaternionicChainRulesCoreExt = "ChainRulesCore" +QuaternionicFastDifferentiationExt = "FastDifferentiation" QuaternionicForwardDiffExt = "ForwardDiff" QuaternionicSymbolicsExt = "Symbolics" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a" diff --git a/ext/QuaternionicFastDifferentiationExt.jl b/ext/QuaternionicFastDifferentiationExt.jl new file mode 100644 index 0000000..c64a642 --- /dev/null +++ b/ext/QuaternionicFastDifferentiationExt.jl @@ -0,0 +1,181 @@ +module QuaternionicFastDifferentiationExt + +using StaticArrays: SVector +using Latexify: latexify +import Quaternionic: AbstractQuaternion, Quaternion, Rotor, QuatVec, + quaternion, rotor, quatvec, + QuatVecF64, RotorF64, QuaternionF64, + wrapper, components, _pm_ascii, _pm_latex +using PrecompileTools +isdefined(Base, :get_extension) ? (using FastDifferentiation) : (using ..FastDifferentiation) + + +### TYPE PIRACY!!! +# The following Base functions should be added to FastDifferentiation itself; meanwhile +# we'll define them here. + +# This is a workaround that should be fixed in FastDifferentiation.jl; Elaine will probably +# make a PR to fix this. The problem is that FD only defines +# `Base.promote_rule(::Type{<:Real}, ::Type{Node})`. But julia/src/bool.jl defines +# `promote_rule(::Type{Bool}, ::Type{T}) where T<:Number` and `Node <: Number`, so there is +# an ambiguity. +Base.promote_rule(::Type{Bool}, ::Type{FastDifferentiation.Node}) = FastDifferentiation.Node + +# These are essentially copied from Symbolics.jl: +# https://github.com/JuliaSymbolics/Symbolics.jl/blob/e4c328103ece494eaaab2a265524a64bfbe43dbd/src/num.jl#L31-L34 +Base.eps(::Type{FastDifferentiation.Node}) = FastDifferentiation.Node(0) +Base.typemin(::Type{FastDifferentiation.Node}) = FastDifferentiation.Node(-Inf) +Base.typemax(::Type{FastDifferentiation.Node}) = FastDifferentiation.Node(Inf) +Base.float(x::FastDifferentiation.Node) = x + +# This one is needed because julia/base/float.jl only defines `isinf` for `Real`, but `Node +# <: Number`. (See https://github.com/brianguenter/FastDifferentiation.jl/issues/73) +Base.isinf(x::FastDifferentiation.Node) = !isnan(x) & !isfinite(x) + + +normalize(v::AbstractVector{FastDifferentiation.Node}) = v ./ √sum(x->x^2, v) + + +### Functions that used to appear in quaternion.jl +quaternion(w::FastDifferentiation.Node) = quaternion(SVector{4}(w, false, false, false)) +rotor(w::FastDifferentiation.Node) = rotor(SVector{4}(one(w), false, false, false)) +quatvec(w::FastDifferentiation.Node) = quatvec(SVector{4,typeof(w)}(false, false, false, false)) +for QT1 ∈ (AbstractQuaternion, Quaternion, QuatVec, Rotor) + @eval begin + wrapper(::Type{<:$QT1}, ::Val{OP}, ::Type{<:FastDifferentiation.Node}) where {OP} = quaternion + wrapper(::Type{<:FastDifferentiation.Node}, ::Val{OP}, ::Type{<:$QT1}) where {OP} = quaternion + end +end +let NT = FastDifferentiation.Node + for QT ∈ (QuatVec,) + for OP ∈ (Val{*}, Val{/}) + @eval begin + wrapper(::Type{<:$QT}, ::$OP, ::Type{<:$NT}) = quatvec + wrapper(::Type{<:$NT}, ::$OP, ::Type{<:$QT}) = quatvec + end + end + end + for QT ∈ (Rotor,) + for OP ∈ (Val{+}, Val{-}, Val{*}, Val{/}) + @eval begin + wrapper(::Type{<:$QT}, ::$OP, ::Type{<:$NT}) = quaternion + wrapper(::Type{<:$NT}, ::$OP, ::Type{<:$QT}) = quaternion + end + end + end +end +let T = FastDifferentiation.Node + for OP ∈ (Val{+}, Val{-}, Val{*}, Val{/}) + @eval wrapper(::Type{<:Quaternion}, ::$OP, ::Type{<:$T}) = quaternion + if T !== Quaternion + @eval wrapper(::Type{<:$T}, ::$OP, ::Type{<:Quaternion}) = quaternion + end + end +end +Base.promote_rule(::Type{Q}, ::Type{S}) where {Q<:AbstractQuaternion,S<:FastDifferentiation.Node} = + wrapper(Q){promote_type(eltype(Q), S)} + + +# function _pm_ascii(x::FastDifferentiation.Node) +# # Utility function to print a component of a quaternion +# s = "$x" +# if s[1] ∉ "+-" +# s = "+" * s +# end +# if occursin(r"[+^/-]", s[2:end]) +# if s[1] == '+' +# s = " + " * "(" * s[2:end] * ")" +# else +# s = " + " * "(" * s * ")" +# end +# else +# s = " " * s[1] * " " * s[2:end] +# end +# s +# end +# function _pm_latex(x::Num) +# # Utility function to print a component of a quaternion in LaTeX +# s = latexify(x, env=:raw, bracket=true) +# if s[1] ∉ "+-" +# s = "+" * s +# end +# if occursin(r"[+^/-]", s[2:end]) +# if s[1] == '+' +# s = " + " * "\\left(" * s[2:end] * "\\right)" +# else +# s = " + " * "\\left(" * s * "\\right)" +# end +# else +# s = " " * s[1] * " " * s[2:end] +# end +# s +# end + + +# # Broadcast-like operations from FastDifferentiation +# # (d::FastDifferentiation.Operator)(q::QT) where {QT<:AbstractQuaternion} = QT(d(q[1]), d(q[2]), d(q[3]), d(q[4])) +# # (d::FastDifferentiation.Operator)(q::QuatVec) = quatvec(d(q[2]), d(q[3]), d(q[4])) +# (d::FastDifferentiation.Differential)(q::Quaternion) = quaternion(d(q[1]), d(q[2]), d(q[3]), d(q[4])) +# (d::FastDifferentiation.Differential)(q::Rotor) = quaternion(d(q[1]), d(q[2]), d(q[3]), d(q[4])) +# (d::FastDifferentiation.Differential)(q::QuatVec) = quatvec(d(q[2]), d(q[3]), d(q[4])) + + +### Functions that used to appear in algebra.jl +for TA ∈ (AbstractQuaternion, Rotor, QuatVec) + let TB = FastDifferentiation.Node + @eval begin + Base.:+(q::QT, p::$TB) where {QT<:$TA} = wrapper($TA, Val(+), $TB)(q[1]+p, q[2], q[3], q[4]) + Base.:-(q::QT, p::$TB) where {QT<:$TA} = wrapper($TA, Val(-), $TB)(q[1]-p, q[2], q[3], q[4]) + Base.:+(p::$TB, q::QT) where {QT<:$TA} = wrapper($TB, Val(+), $TA)(p+q[1], q[2], q[3], q[4]) + Base.:-(p::$TB, q::QT) where {QT<:$TA} = wrapper($TB, Val(-), $TA)(p-q[1], -q[2], -q[3], -q[4]) + end + end +end +let S = FastDifferentiation.Node + @eval begin + Base.:*(p::Q, s::$S) where {Q<:AbstractQuaternion} = wrapper(Q, Val(*), $S)(s*components(p)) + Base.:*(s::$S, p::Q) where {Q<:AbstractQuaternion} = wrapper($S, Val(*), Q)(s*components(p)) + Base.:/(p::Q, s::$S) where {Q<:AbstractQuaternion} = wrapper(Q, Val(/), $S)(components(p)/s) + function Base.:/(s::$S, p::Q) where {Q<:AbstractQuaternion} + f = s / abs2(p) + wrapper($S, Val(/), Q)(p[1] * f, -p[2] * f, -p[3] * f, -p[4] * f) + end + end +end + + +# Pre-compilation + +@setup_workload begin + # Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the + # size of the precompile file and potentially make loading faster. + FastDifferentiation.@variables w x y z a b c d e + s = randn(Float64) + v = randn(QuatVecF64) + r = randn(RotorF64) + q = randn(QuaternionF64) + 𝓈 = w + 𝓋 = quatvec(x, y, z) + 𝓇 = rotor(a, b, c, d) + 𝓆 = quaternion(w, x, y, z) + + @compile_workload begin + # all calls in this block will be precompiled, regardless of whether they belong to + # this package or not (on Julia 1.8 and higher) + r(v) + 𝓇(𝓋) + for a ∈ [s, v, r, q, 𝓈, 𝓋, 𝓇, 𝓆] + conj(a) + for b ∈ [s, v, r, q, 𝓈, 𝓋, 𝓇, 𝓆] + a * b + a / b + a + b + a - b + end + end + + end +end + + +end # module diff --git a/ext/QuaternionicSymbolicsExt.jl b/ext/QuaternionicSymbolicsExt.jl index f0bf4fd..46633a8 100644 --- a/ext/QuaternionicSymbolicsExt.jl +++ b/ext/QuaternionicSymbolicsExt.jl @@ -10,6 +10,8 @@ using PrecompileTools isdefined(Base, :get_extension) ? (using Symbolics) : (using ..Symbolics) +normalize(v::AbstractVector{Symbolics.Num}) = v ./ √sum(x->x^2, v) + ### Functions that used to appear in quaternion.jl quaternion(w::Symbolics.Num) = quaternion(SVector{4}(w, false, false, false)) rotor(w::Symbolics.Num) = rotor(SVector{4}(one(w), false, false, false)) @@ -257,7 +259,7 @@ end @compile_workload begin # all calls in this block will be precompiled, regardless of whether they belong to - # your package or not (on Julia 1.8 and higher) + # this package or not (on Julia 1.8 and higher) r(v) Symbolics.simplify.(𝓇(𝓋)) for a ∈ [s, v, r, q, 𝓈, 𝓋, 𝓇, 𝓆] diff --git a/src/Quaternionic.jl b/src/Quaternionic.jl index 379bc83..c55313f 100644 --- a/src/Quaternionic.jl +++ b/src/Quaternionic.jl @@ -53,6 +53,7 @@ end function __init__() @require ChainRulesCore="d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include("../ext/QuaternionicChainRulesCoreExt.jl") + @require FastDifferentiation="eb9bf01b-bf85-4b60-bf87-ee5de06c00be" include("../ext/QuaternionicFastDifferentiationExt.jl") @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("../ext/QuaternionicForwardDiffExt.jl") @require Symbolics="0c5d862f-8b57-4792-8d23-62f2024744c7" include("../ext/QuaternionicSymbolicsExt.jl") end diff --git a/src/examples.jl b/src/examples.jl index 61a782b..35a20c3 100644 --- a/src/examples.jl +++ b/src/examples.jl @@ -33,7 +33,7 @@ julia> R, ω⃗, Ṙ = precessing_nutating_example(); julia> R(12.34) rotor(0.9944579779058746 + 0.09804177421238346𝐢 - 0.0008485045352531196𝐣 + 0.03795287510453948𝐤) julia> ω⃗(345.67) - + 0.0004634300734286701𝐢 - 0.0007032818419003175𝐣 + 0.006214814810035088𝐤 + + 0.00046343007342866996𝐢 - 0.0007032818419003173𝐣 + 0.006214814810035087𝐤 julia> ϵ = 1e-6; (R(ϵ) - R(-ϵ)) / 2ϵ # Approximate derivative at t=0 -3.8491432263754177e-7 + (3.9080960689830135e-6)𝐢 - (6.861695854245622e-5)𝐣 + 0.003076329202503836𝐤 julia> Ṙ(0) diff --git a/test/Project.toml b/test/Project.toml index 9eabed5..f6c5a05 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,6 +5,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" +FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/test/algebra.jl b/test/algebra.jl index 38e6a19..9794ce8 100644 --- a/test/algebra.jl +++ b/test/algebra.jl @@ -98,13 +98,13 @@ end chars = Iterators.Stateful(Iterators.cycle("abcdefghijkl")) function next_scalar!(chars) x = Symbol(popfirst!(chars)) - xvar = @variables $x + xvar = Symbolics.@variables $x xvar[1] end function next_quaternion!(chars) x = Symbol(popfirst!(chars)) # May have to work around : - xvar = @variables $x[1:4] + xvar = Symbolics.@variables $x[1:4] quaternion(xvar[1]...) end diff --git a/test/base.jl b/test/base.jl index 4002b7d..3f82bd1 100644 --- a/test/base.jl +++ b/test/base.jl @@ -1,9 +1,9 @@ @testset verbose=true "Base" begin @testset "Numbers $T" for T in Types - # Note that, because `Num` from Symbolics is a weird type, we have to - # be a little more explicit below than we normally would be. Also, - # because of signed zeros in the float types, we have to take the - # absolute value of the difference before comparing to zero. + # Note that, because `Symbolics.Num` is a weird type, we have to be a little more + # explicit below than we normally would be. Also, because of signed zeros in the + # float types, we have to take the absolute value of the difference before comparing + # to zero. # Define basis elements u = Quaternion{T}(1) @@ -42,7 +42,7 @@ if !(T<:Integer) @test rotor(T(1), 2, 3, 4) == Rotor{T}(SVector{4, T}(1, 2, 3, 4)/√T(30)) @test Rotor{T}(1, 0, 0, 0) == Rotor{T}(1) - if !(T<:Num) + if !(T<:Symbolics.Num) @test rotor(T[1, 2, 3, 4]...) ≈ rotor(SVector{4, T}(1, 2, 3, 4)/√T(30)) rtol=0 atol=2eps(T) @test rotor(T[0, 2, 3, 4]...) ≈ rotor(T(2), T(3), T(4)) rtol=0 atol=2eps(T) @test rotor(T[1, 0, 0, 0]...) ≈ rotor(T(1)) rtol=0 atol=2eps(T) @@ -81,24 +81,24 @@ end for v ∈ [𝐢, 𝐣, 𝐤] - @test Num(1)*v != one(T) - @test !isequal(Num(1)*v, one(T)) - @test Quaternion(Num(one(T))) == one(T) - @test Quaternion(Num(7one(T))) == 7one(T) - @test one(T) == Quaternion(Num(one(T))) - @test 7one(T) == Quaternion(Num(7one(T))) - - @test QuatVec(Num[1,2,3,4]) != one(T) - @test QuatVec(Num[7,2,3,4]) != 7one(T) - @test one(T) != QuatVec(Num[1,2,3,4]) - @test 7one(T) != QuatVec(Num[7,2,3,4]) - - @test QuatVec(Num[1,2,3,4]) == QuatVec{T}(0,2,3,4) - @test QuatVec{T}(0,2,3,4) == QuatVec(Num[1,2,3,4]) - @test QuatVec(Num[1,2,3,4]) == Quaternion{T}(0,2,3,4) - @test Quaternion{T}(0,2,3,4) == QuatVec(Num[1,2,3,4]) - @test QuatVec(Num[1,2,3,4]) != Quaternion{T}(1,2,3,4) - @test Quaternion{T}(1,2,3,4) != QuatVec(Num[1,2,3,4]) + @test Symbolics.Num(1)*v != one(T) + @test !isequal(Symbolics.Num(1)*v, one(T)) + @test Quaternion(Symbolics.Num(one(T))) == one(T) + @test Quaternion(Symbolics.Num(7one(T))) == 7one(T) + @test one(T) == Quaternion(Symbolics.Num(one(T))) + @test 7one(T) == Quaternion(Symbolics.Num(7one(T))) + + @test QuatVec(Symbolics.Num[1,2,3,4]) != one(T) + @test QuatVec(Symbolics.Num[7,2,3,4]) != 7one(T) + @test one(T) != QuatVec(Symbolics.Num[1,2,3,4]) + @test 7one(T) != QuatVec(Symbolics.Num[7,2,3,4]) + + @test QuatVec(Symbolics.Num[1,2,3,4]) == QuatVec{T}(0,2,3,4) + @test QuatVec{T}(0,2,3,4) == QuatVec(Symbolics.Num[1,2,3,4]) + @test QuatVec(Symbolics.Num[1,2,3,4]) == Quaternion{T}(0,2,3,4) + @test Quaternion{T}(0,2,3,4) == QuatVec(Symbolics.Num[1,2,3,4]) + @test QuatVec(Symbolics.Num[1,2,3,4]) != Quaternion{T}(1,2,3,4) + @test Quaternion{T}(1,2,3,4) != QuatVec(Symbolics.Num[1,2,3,4]) end # Test indexing @@ -145,7 +145,7 @@ # Check isapprox @test u ≈ one(T) @test one(T) ≈ u - if T !== Num # Num(1) ≉ Num(2) doesn't work + if T !== Symbolics.Num # Symbolics.Num(1) ≉ Symbolics.Num(2) doesn't work @test i ≉ one(T) @test one(T) ≉ i @test j ≉ one(T) @@ -184,7 +184,7 @@ @test !isreal(k) # Check "isinteger" - if T != Num + if T != Symbolics.Num @test isinteger(u) @test !isinteger(1.2u) end @@ -243,7 +243,7 @@ end end - if T != Num + if T != Symbolics.Num # Check "in" @test u ∈ 0:2 @test u ∉ 2:4 @@ -347,8 +347,8 @@ end @testset "Differential" begin - @variables t Q(t)[1:4] R(t)[1:4] V(t)[1:3] - ∂ₜ = Differential(t) + Symbolics.@variables t Q(t)[1:4] R(t)[1:4] V(t)[1:3] + ∂ₜ = Symbolics.Differential(t) Q = quaternion(Q...) R = rotor(R...) V = quatvec(V...) diff --git a/test/runtests.jl b/test/runtests.jl index 82519ab..ddff221 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,18 +21,19 @@ around the code you don't want to measure: using Quaternionic using Test -using Random, Symbolics, StaticArrays, ForwardDiff, GenericLinearAlgebra, +using Random, StaticArrays, ForwardDiff, GenericLinearAlgebra, ChainRulesTestUtils, Zygote, ChainRulesTestUtils, Aqua +import Symbolics, FastDifferentiation import LinearAlgebra using ChainRulesCore ChainRulesCore.debug_mode() = true -@variables w x y z a b c d e # Symbolic variables +Symbolics.@variables w x y z a b c d e # Symbolic variables # NOTE: `FloatTypes` and `IntTypes` must be in descending order of width FloatTypes = [BigFloat, Float64, Float32, Float16] IntTypes = [BigInt, Int128, Int64, Int32, Int16, Int8] -SymbolicTypes = [Num] +SymbolicTypes = [Symbolics.Num] Types = [FloatTypes...; IntTypes...; SymbolicTypes...] PrimitiveTypes = [T for T in Types if isbitstype(T)] @@ -41,8 +42,8 @@ QTypes = [Quaternion, Rotor, QuatVec] # Handy assignments for now Base.eps(::Quaternion{T}) where {T} = eps(T) Base.eps(T::Type{<:Integer}) = zero(T) -Base.eps(n::Num) = zero(n) -Base.:≈(a::Num, b::Num; kwargs...) = iszero(Symbolics.simplify(a-b; expand=true)) +Base.eps(n::Symbolics.Num) = zero(n) +Base.:≈(a::Symbolics.Num, b::Symbolics.Num; kwargs...) = iszero(Symbolics.simplify(a-b; expand=true)) enabled_tests = lowercase.(ARGS)