diff --git a/Project.toml b/Project.toml index 195cccac5..63fd50344 100644 --- a/Project.toml +++ b/Project.toml @@ -7,20 +7,17 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -[weakdeps] -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[extensions] -ChainRulesCoreSparseArraysExt = "SparseArrays" - [compat] BenchmarkTools = "0.5" -Compat = "2, 3, 4" +Compat = "3.40, 4" FiniteDifferences = "0.10" OffsetArrays = "1" StaticArrays = "0.11, 0.12, 1" julia = "1.6" +[extensions] +ChainRulesCoreSparseArraysExt = "SparseArrays" + [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" @@ -31,3 +28,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "SparseArrays", "StaticArrays"] + +[weakdeps] +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/docs/make.jl b/docs/make.jl index ad86a84ae..1666665fe 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -61,6 +61,7 @@ makedocs(; "`@opt_out`" => "rule_author/superpowers/opt_out.md", "`RuleConfig`" => "rule_author/superpowers/ruleconfig.md", "Gradient accumulation" => "rule_author/superpowers/gradient_accumulation.md", + "Mutation Support (experimental)" => "rule_author/superpowers/mutation_support.md", ], "Converting ZygoteRules.@adjoint to rrules" => "rule_author/converting_zygoterules.md", "Tips for making your package work with AD" => "rule_author/tips_for_packages.md", diff --git a/docs/src/api.md b/docs/src/api.md index 5648058e0..57b7bf2ad 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -20,7 +20,7 @@ Modules = [ChainRulesCore] Pages = [ "tangent_types/abstract_zero.jl", "tangent_types/one.jl", - "tangent_types/tangent.jl", + "tangent_types/structural_tangent.jl", "tangent_types/thunks.jl", "tangent_types/abstract_tangent.jl", 
"tangent_types/notimplemented.jl", diff --git a/docs/src/rule_author/superpowers/mutation_support.md b/docs/src/rule_author/superpowers/mutation_support.md new file mode 100644 index 000000000..a4dec8ab8 --- /dev/null +++ b/docs/src/rule_author/superpowers/mutation_support.md @@ -0,0 +1,82 @@ +# Mutation Support + +ChainRulesCore.jl offers experimental support for mutation, targeting use in forward mode AD. +(Mutation support in reverse mode AD is more complicated and will likely require more changes to the interface.) + +!!! warning "Experimental" + This page documents an experimental feature. + Expect breaking changes in minor versions while this remains. + It is not suitable for general use unless you are prepared to modify how you are using it each minor release. + If you do use it, it is thus suggested that you put _tilde_ bounds on the minor versions you support. + + +## `MutableTangent` +The [`MutableTangent`](@ref) type is designed to be a partner to the [`Tangent`](@ref) type, with specific support for being mutated in place. +It is required to be a structural tangent, having one tangent for each field of the primal object. + +Technically, not all `mutable struct`s need to use `MutableTangent` to represent their tangents. +Just like not all `struct`s need to use `Tangent`s. +Common exceptions are natural tangent types, such as those used for arrays. +However, if one is setting up to use a custom tangent type for this it is sufficiently off the beaten path that we can not provide much guidance. + +## `zero_tangent` + +The [`zero_tangent`](@ref) function gives you a zero (i.e. additive identity) for any primal value. +The [`ZeroTangent`](@ref) type also does this. +The difference is that [`zero_tangent`](@ref) is in general a full structural tangent mirroring the structure of the primal. +To be technical, the promise of [`zero_tangent`](@ref) is that it will be a value that supports mutation. 
+However, in practice[^1] this is achieved through a structural tangent. +For mutation support this is important, since it means that there is mutable memory available in the tangent to be mutated when the primal changes. +To support this you thus need to make sure your zeros are created in various places with [`zero_tangent`](@ref) rather than [`ZeroTangent`](@ref). + + + +It is also useful for reasons of type stability, since it forces a consistent type (generally a structural tangent) for any given primal type. +For this reason AD system implementors might choose to use this to create the tangent for all literal values they encounter, mutable or not, +and to process the output of `frule`s to convert [`ZeroTangent`](@ref) into corresponding [`zero_tangent`](@ref)s. + +## Writing a frule for a mutating function +It is relatively straightforward to write a frule for a mutating function. +There are a few key points to follow: + - There must be a mutable tangent input for every mutated primal input + - When the primal value is changed, the corresponding change must be made to its tangent partner + - When a value is returned, return its partnered tangent. + + +### Example +For example, consider the primal function that: +1. takes two `Ref`s +2. doubles the first one in place +3. overwrites the second one's value with the literal 5.0 +4. 
returns the first one + + +```julia +function foo!(a::Base.RefValue, b::Base.RefValue) + a[] *= 2 + b[] = 5.0 + return a +end +``` + +The frule for this would be: +```julia +function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(foo!), a::Base.RefValue, b::Base.RefValue) + @assert ȧ isa MutableTangent{typeof(a)} + @assert ḃ isa MutableTangent{typeof(b)} + + a[] *= 2 + ȧ.x *= 2 # `.x` is the field that lives behind RefValues + + b[]=5.0 + ḃ.x = zero_tangent(5.0) # or since we know that the zero for a Float64 is zero could write `ḃ.x = 0.0` + + return a, ȧ +end +``` + +Then assuming the AD system does its part to make sure you are indeed given mutable values to mutate (i.e. those `@assert`ions are true) then all is well and this rule will make mutation correct. + +[^1]: + Further, it is hard to achieve this promise of allowing mutation to be supported without returning a structural tangent. + Except in the special case where the struct is not mutable and has no nested fields that are mutable. \ No newline at end of file diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 94e8242b1..286f71db2 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -2,7 +2,7 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! 
using Base.Meta using LinearAlgebra -using Compat: hasfield, hasproperty +using Compat: hasfield, hasproperty, ismutabletype export frule, rrule # core function # rule configurations @@ -10,18 +10,18 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod export frule_via_ad, rrule_via_ad # definition helper macros export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented -export ProjectTo, canonicalize, unthunk # tangent operations +export ProjectTo, canonicalize, unthunk, zero_tangent # tangent operations export add!!, is_inplaceable_destination # gradient accumulation operations export ignore_derivatives, @ignore_derivatives # tangents -export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk +export StructuralTangent, Tangent, MutableTangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk include("debug_mode.jl") include("tangent_types/abstract_tangent.jl") +include("tangent_types/structural_tangent.jl") include("tangent_types/abstract_zero.jl") include("tangent_types/thunks.jl") -include("tangent_types/tangent.jl") include("tangent_types/notimplemented.jl") include("tangent_arithmetic.jl") diff --git a/src/tangent_arithmetic.jl b/src/tangent_arithmetic.jl index 439f0ac8f..18ae7b3ad 100644 --- a/src/tangent_arithmetic.jl +++ b/src/tangent_arithmetic.jl @@ -20,7 +20,7 @@ Base.:+(x::NotImplemented, ::NotImplemented) = x Base.:*(x::NotImplemented, ::NotImplemented) = x LinearAlgebra.dot(x::NotImplemented, ::NotImplemented) = x # `NotImplemented` always "wins" + -for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :Tangent, :Any) +for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :StructuralTangent, :Any) @eval Base.:+(x::NotImplemented, ::$T) = x @eval Base.:+(::$T, x::NotImplemented) = x end @@ -33,7 +33,7 @@ for T in (:ZeroTangent, :NoTangent) @eval LinearAlgebra.dot(::$T, ::NotImplemented) = $T() end # `NotImplemented` "wins" * and dot for other 
types -for T in (:AbstractThunk, :Tangent, :Any) +for T in (:AbstractThunk, :StructuralTangent, :Any) @eval Base.:*(x::NotImplemented, ::$T) = x @eval Base.:*(::$T, x::NotImplemented) = x @eval LinearAlgebra.dot(x::NotImplemented, ::$T) = x @@ -55,7 +55,7 @@ Base.:-(::NoTangent, ::NoTangent) = NoTangent() Base.:-(::NoTangent) = NoTangent() Base.:*(::NoTangent, ::NoTangent) = NoTangent() LinearAlgebra.dot(::NoTangent, ::NoTangent) = NoTangent() -for T in (:AbstractThunk, :Tangent, :Any) +for T in (:AbstractThunk, :StructuralTangent, :Any) @eval Base.:+(::NoTangent, b::$T) = b @eval Base.:+(a::$T, ::NoTangent) = a @eval Base.:-(::NoTangent, b::$T) = -b @@ -95,7 +95,7 @@ Base.:-(::ZeroTangent, ::ZeroTangent) = ZeroTangent() Base.:-(::ZeroTangent) = ZeroTangent() Base.:*(::ZeroTangent, ::ZeroTangent) = ZeroTangent() LinearAlgebra.dot(::ZeroTangent, ::ZeroTangent) = ZeroTangent() -for T in (:AbstractThunk, :Tangent, :Any) +for T in (:AbstractThunk, :StructuralTangent, :Any) @eval Base.:+(::ZeroTangent, b::$T) = b @eval Base.:+(a::$T, ::ZeroTangent) = a @eval Base.:-(::ZeroTangent, b::$T) = -b @@ -126,11 +126,11 @@ for T in (:Tangent, :Any) @eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b) end -function Base.:+(a::Tangent{P}, b::Tangent{P}) where {P} +function Base.:+(a::StructuralTangent{P}, b::StructuralTangent{P}) where {P} data = elementwise_add(backing(a), backing(b)) - return Tangent{P,typeof(data)}(data) + return StructuralTangent{P}(data) end -function Base.:+(a::P, d::Tangent{P}) where {P} +function Base.:+(a::P, d::StructuralTangent{P}) where {P} net_backing = elementwise_add(backing(a), backing(d)) if debug_mode() try @@ -143,14 +143,14 @@ function Base.:+(a::P, d::Tangent{P}) where {P} end end Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d)) -Base.:+(a::Tangent{P}, b::P) where {P} = b + a +Base.:+(a::StructuralTangent{P}, b::P) where {P} = b + a -Base.:-(tangent::Tangent{P}) where {P} = map(-, tangent) 
+Base.:-(tangent::StructuralTangent{P}) where {P} = map(-, tangent) # We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful # In general one doesn't have to represent multiplications of 2 tangents # Only of a tangent and a scaling factor (generally `Real`) for T in (:Number,) - @eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent) - @eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent) + @eval Base.:*(s::$T, tangent::StructuralTangent) = map(x -> s * x, tangent) + @eval Base.:*(tangent::StructuralTangent, s::$T) = map(x -> x * s, tangent) end diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 77c455c04..e51a99f09 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -91,3 +91,116 @@ arguments. ``` """ struct NoTangent <: AbstractZero end + +""" + zero_tangent(primal, _cache=nothing) + +This returns an appropriate zero tangent suitable for accumulating tangents of the primal. +For mutable composite types this is a structural [`MutableTangent`](@ref). +For `Array`s, it is applied recursively for each element. +For other types, in particular immutable types, we do not make promises beyond that it will be `iszero` +and suitable for accumulating against. +For types without a tangent space (e.g. singleton structs) this returns `NoTangent()`. +In general, it is more likely to produce a structural tangent. + +!!! warning "Experimental" + `zero_tangent` is an experimental feature, and is part of the mutation support featureset. + While this notice remains it may have changes in behaviour and interface in any _minor_ version of ChainRulesCore. + Exactly how it should be used (e.g. is it forward-mode only?) is not yet settled. + +The `_cache=nothing` is an internal implementation detail that the user should never need to set. 
+(It is used to hold references to tangents for that might appear in self-referential structures) +""" +function zero_tangent end + +zero_tangent(x::Number, _cache=nothing) = zero(x) + +zero_tangent(::Type, _cache=nothing) = NoTangent() + +function zero_tangent(x::MutableTangent{P}, _cache=nothing) where {P} + zb = backing(zero_tangent(backing(x), _cache)) + return MutableTangent{P}(zb) +end + +function zero_tangent(x::Tangent{P}, _cache=nothing) where {P} + zb = backing(zero_tangent(backing(x), _cache)) + return Tangent{P,typeof(zb)}(zb) +end + +@generated function zero_tangent(primal, _cache=nothing) + fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero. + zfield_exprs = map(fieldnames(primal)) do fname + :( + if isdefined(primal, $(QuoteNode(fname))) + zero_tangent(getfield(primal, $(QuoteNode(fname))), _cache) + else + # This is going to be potentially bad, but that's what they get for not giving us a primal + # This will never me mutated inplace, rather it will alway be replaced with an actual value first + ZeroTangent() + end + ) + end + return if has_mutable_tangent(primal) + # This is a little complex because we need to support-self referential types + # So we need to: + # 1. create the tangent, + # 2. put it in the cache + # 3. Do all the calls to create the zeros for the fields giving them that cache) + # 4. put those zeros into the object + tangent_types = map(guess_zero_tangent_type, fieldtypes(primal)) + is_defined_mask = Expr(:tuple, map(fieldnames(primal)) do fname + :(isdefined(primal, $(QuoteNode(fname)))) + end...) 
+ + quote + isnothing(_cache) && (_cache = IdDict()) + found_tangent = get(_cache, primal, nothing) + !isnothing(found_tangent) && return found_tangent + + # Now we need to put into the cache a placeholder tangent so we can construct our fields using that cache + # then put those fields into the placeholder + tangent = $_MutableTangent(Val{$primal}(), $is_defined_mask, $tangent_types) + _cache[primal] = tangent + $( + map(fieldnames(primal), zfield_exprs) do fname, fval_expr + :(setproperty!(tangent, $(QuoteNode(fname)), $fval_expr)) + end... + ) + return tangent + end + else + :($Tangent{$primal}($(Expr(:parameters, Expr.(:kw, fieldnames(primal), zfield_exprs)...)))) + end +end + +function zero_tangent(primal::Tuple, _cache=nothing) + return Tangent{typeof(primal)}(map(x -> zero_tangent(x, _cache), primal)...) +end + +function zero_tangent(x::Array{P,N}, _cache=nothing) where {P,N} + if (isbitstype(P) || all(i -> isassigned(x, i), eachindex(x))) + return map(xi -> zero_tangent(xi, _cache), x) # forward `_cache` so aliased/self-referential elements reuse one tangent + end + + # Now we need to handle nonfully assigned arrays + # see discussion at https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626#discussion_r1345235265 + y = Array{guess_zero_tangent_type(P),N}(undef, size(x)...) 
+ @inbounds for n in eachindex(y) + if isassigned(x, n) + y[n] = zero_tangent(x[n], _cache) + end + end + return y +end + +# Sad heuristic methods +#guess_zero_tangent_type(::Type{T}) where {T<:Number} = T +#guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T))) +function guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} + return Array{guess_zero_tangent_type(T),N} +end + +# The following will fall back to `Any` if it is hard to infer +function guess_zero_tangent_type(::Type{T}) where {T} + return Core.Compiler.return_type(zero_tangent, Tuple{T}) +end \ No newline at end of file diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/structural_tangent.jl similarity index 68% rename from src/tangent_types/tangent.jl rename to src/tangent_types/structural_tangent.jl index 6af968c53..71228485a 100644 --- a/src/tangent_types/tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -1,5 +1,20 @@ """ - Tangent{P, T} <: AbstractTangent + StructuralTangent{P} <: AbstractTangent + +Representing the type of the tangent of a `struct` `P` (or a `Tuple`/`NamedTuple`) +as an object with mirroring fields. + +!!! warning "Experimental" + `StructuralTangent` is an experimental feature, and is part of the mutation support featureset. + The `StructuralTangent` constructor returns a `MutableTangent` for mutable structs. + `MutableTangent` is an experimental feature. + Thus use of `StructuralTangent` (rather than `Tangent` directly) is also experimental. + While this notice remains it may have changes in behaviour and interface in any _minor_ version of ChainRulesCore. +""" +abstract type StructuralTangent{P} <: AbstractTangent end + +""" + Tangent{P, T} <: StructuralTangent{P} <: AbstractTangent This type represents the tangent for a `struct`/`NamedTuple`, or `Tuple`. `P` is the the corresponding primal type that this is a tangent for. 
@@ -21,7 +36,7 @@ Any fields not explictly present in the `Tangent` are treated as being set to `Z To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref) function is provided. """ -struct Tangent{P,T} <: AbstractTangent +struct Tangent{P,T} <: StructuralTangent{P} # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict # (but potentially a different one, as it doesn't contain tangents) backing::T @@ -39,41 +54,115 @@ struct Tangent{P,T} <: AbstractTangent end end -function Tangent{P}(; kwargs...) where {P} - backing = (; kwargs...) # construct as NamedTuple - return Tangent{P,typeof(backing)}(backing) -end +function _MutableTangent end +""" + MutableTangent{P}(fields) <: StructuralTangent{P} <: AbstractTangent -function Tangent{P}(args...) where {P} - return Tangent{P,typeof(args)}(args) -end +This type represents the tangent to a mutable struct. +It itself is also mutable. -function Tangent{P}() where {P<:Tuple} - backing = () - return Tangent{P,typeof(backing)}(backing) -end +!!! warning "Experimental" + `MutableTangent` is an experimental feature, and is part of the mutation support featureset. + While this notice remains it may have changes in behaviour and interface in any _minor_ version of ChainRulesCore. + Exactly how it should be used (e.g. is it forward-mode only?) is not yet settled. -function Tangent{P}(d::Dict) where {P<:Dict} - return Tangent{P,typeof(d)}(d) +!!! warning "Do not directly mess with the tangent backing data" + It is relatively straightforward for a forwards-mode AD to work correctly in the presence of mutation and aliasing of primal values. + However, this requires that the tangent is aliased in turn (and conversely that it is copied when the primal is). + If you separately alias the backing data, etc. by using the internal `ChainRulesCore.backing` function you can break this. 
+""" +struct MutableTangent{P,F} <: StructuralTangent{P} + backing::F + + # Uninitialized constructor + global function _MutableTangent(::Val{P}, is_defined_mask, tangent_types) where {P} + backing_vals = map(is_defined_mask, tangent_types) do is_def, tangent_type + ref = if !is_def + Ref{Union{ZeroTangent, tangent_type}} # allow a Zero which will be used for uninitialized values + else + Ref{tangent_type} + end + return ref() # undefined, but it will be filled later + end + backing = NamedTuple{fieldnames(P)}(backing_vals) + return new{P, typeof(backing)}(backing) + end + + # TODO: are the following two correct? + # Are they useful? + # The place they are used is just `map`, maybe we should instead just copy types the thing being mapped? + function MutableTangent{P}( + any_mask::NamedTuple{names, <:NTuple{<:Any, Bool}}, fvals::NamedTuple{names} + ) where {names, P} + + backing = map(any_mask, fvals) do isany, fval + ref = if isany + Ref{Any} + else + Ref + end + return ref(fval) + end + return new{P, typeof(backing)}(backing) + end + + + function MutableTangent{P}(fvals) where P + any_mask = NamedTuple{fieldnames(P)}((!isconcretetype).(fieldtypes(P))) + return MutableTangent{P}(any_mask, fvals) + end end -function _backing_error(P, G, E) - msg = "Tangent for the primal $P should be backed by a $E type, not by $G." 
- return throw(ArgumentError(msg)) +#################################################################### +# StructuralTangent Common + + +function StructuralTangent{P}(nt::NamedTuple) where {P} + if has_mutable_tangent(P) + return MutableTangent{P}(nt) + else + return Tangent{P,typeof(nt)}(nt) + end end -function Base.:(==)(a::Tangent{P,T}, b::Tangent{P,T}) where {P,T} - return backing(a) == backing(b) + +has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(P) > 0) + + +StructuralTangent{P}(tup::Tuple) where P = Tangent{P,typeof(tup)}(tup) +StructuralTangent{P}(dict::Dict) where P = Tangent{P}(dict) + +Base.keys(tangent::StructuralTangent) = keys(backing(tangent)) +Base.propertynames(tangent::StructuralTangent) = propertynames(backing(tangent)) + +Base.haskey(tangent::StructuralTangent, key) = haskey(backing(tangent), key) +if isdefined(Base, :hasproperty) + function Base.hasproperty(tangent::StructuralTangent, key::Symbol) + return hasproperty(backing(tangent), key) + end end -function Base.:(==)(a::Tangent{P}, b::Tangent{P}) where {P} - all_fields = union(keys(backing(a)), keys(backing(b))) - return all(getproperty(a, f) == getproperty(b, f) for f in all_fields) + +Base.iszero(t::StructuralTangent) = all(iszero, backing(t)) + +function Base.map(f, tangent::StructuralTangent{P}) where {P} + #TODO: is it even useful to support this on MutableTangents? 
+ #TODO: we implictly assume only linear `f` are called and that it is safe to ignore noncanonical Zeros + # This feels like a fair assumption since all normal operations on tangents are linear + L = propertynames(backing(tangent)) + vals = map(f, Tuple(backing(tangent))) + named_vals = NamedTuple{L,typeof(vals)}(vals) + return if tangent isa MutableTangent + MutableTangent{P}(named_vals) + else + Tangent{P,typeof(named_vals)}(named_vals) + end end -Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P,Q} = false -Base.hash(a::Tangent, h::UInt) = Base.hash(backing(canonicalize(a)), h) -function Base.show(io::IO, tangent::Tangent{P}) where {P} +function Base.show(io::IO, tangent::StructuralTangent{P}) where {P} + if tangent isa MutableTangent + print(io, "Mutable") + end print(io, "Tangent{") str = sprint(show, P, context = io) i = findfirst('{', str) @@ -96,80 +185,6 @@ function Base.show(io::IO, tangent::Tangent{P}) where {P} end end -Base.iszero(::Tangent{<:,NamedTuple{}}) = true -Base.iszero(::Tangent{<:,Tuple{}}) = true -Base.iszero(t::Tangent) = all(iszero, backing(t)) - -Base.first(tangent::Tangent{P,T}) where {P,T<:Union{Tuple,NamedTuple}} = first(backing(canonicalize(tangent))) -Base.last(tangent::Tangent{P,T}) where {P,T<:Union{Tuple,NamedTuple}} = last(backing(canonicalize(tangent))) - -Base.tail(t::Tangent{P}) where {P<:Tuple} = Tangent{_tailtype(P)}(Base.tail(backing(canonicalize(t)))...) -@generated _tailtype(::Type{P}) where {P<:Tuple} = Tuple{P.parameters[2:end]...} -Base.tail(t::Tangent{<:Tuple{Any}}) = NoTangent() -Base.tail(t::Tangent{<:Tuple{}}) = NoTangent() - -Base.tail(t::Tangent{P}) where {P<:NamedTuple} = Tangent{_tailtype(P)}(; Base.tail(backing(canonicalize(t)))...) 
-_tailtype(::Type{NamedTuple{S,P}}) where {S,P} = NamedTuple{Base.tail(S), _tailtype(P)} -Base.tail(t::Tangent{<:NamedTuple{<:Any, <:Tuple{Any}}}) = NoTangent() -Base.tail(t::Tangent{<:NamedTuple{<:Any, <:Tuple{}}}) = NoTangent() - -function Base.getindex(tangent::Tangent{P,T}, idx::Int) where {P,T<:Union{Tuple,NamedTuple}} - back = backing(canonicalize(tangent)) - return unthunk(getfield(back, idx)) -end -function Base.getindex(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedTuple} - hasfield(T, idx) || return ZeroTangent() - return unthunk(getfield(backing(tangent), idx)) -end -function Base.getindex(tangent::Tangent, idx) - return unthunk(getindex(backing(tangent), idx)) -end - -function Base.getproperty(tangent::Tangent, idx::Int) - back = backing(canonicalize(tangent)) - return unthunk(getfield(back, idx)) -end -function Base.getproperty(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedTuple} - hasfield(T, idx) || return ZeroTangent() - return unthunk(getfield(backing(tangent), idx)) -end - -Base.keys(tangent::Tangent) = keys(backing(tangent)) -Base.propertynames(tangent::Tangent) = propertynames(backing(tangent)) - -Base.haskey(tangent::Tangent, key) = haskey(backing(tangent), key) -if isdefined(Base, :hasproperty) - Base.hasproperty(tangent::Tangent, key::Symbol) = hasproperty(backing(tangent), key) -end - -Base.iterate(tangent::Tangent, args...) = iterate(backing(tangent), args...) 
-Base.length(tangent::Tangent) = length(backing(tangent)) - -Base.eltype(::Type{<:Tangent{<:Any,T}}) where {T} = eltype(T) -function Base.reverse(tangent::Tangent) - rev_backing = reverse(backing(tangent)) - return Tangent{typeof(rev_backing),typeof(rev_backing)}(rev_backing) -end - -function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state=1) where {P} - return Base.indexed_iterate(backing(tangent), i, state) -end - -function Base.map(f, tangent::Tangent{P,<:Tuple}) where {P} - vals::Tuple = map(f, backing(tangent)) - return Tangent{P,typeof(vals)}(vals) -end -function Base.map(f, tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} - vals = map(f, Tuple(backing(tangent))) - named_vals = NamedTuple{L,typeof(vals)}(vals) - return Tangent{P,typeof(named_vals)}(named_vals) -end -function Base.map(f, tangent::Tangent{P,<:Dict}) where {P<:Dict} - return Tangent{P}(Dict(k => f(v) for (k, v) in backing(tangent))) -end - -Base.conj(tangent::Tangent) = map(conj, tangent) - """ backing(x) @@ -184,6 +199,7 @@ backing(x::Tuple) = x backing(x::NamedTuple) = x backing(x::Dict) = x backing(x::Tangent) = getfield(x, :backing) +backing(x::MutableTangent) = map(getindex, getfield(x, :backing)) # For generic structs function backing(x::T)::NamedTuple where {T} @@ -211,39 +227,6 @@ function backing(x::T)::NamedTuple where {T} end end -""" - canonicalize(tangent::Tangent{P}) -> Tangent{P} - -Return the canonical `Tangent` for the primal type `P`. -The property names of the returned `Tangent` match the field names of the primal, -and all fields of `P` not present in the input `tangent` are explictly set to `ZeroTangent()`. -""" -function canonicalize(tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} - nil = _zeroed_backing(P) - combined = merge(nil, backing(tangent)) - if length(combined) !== fieldcount(P) - throw( - ArgumentError( - "Tangent fields do not match primal fields.\n" * - "Tangent fields: $L. 
Primal ($P) fields: $(fieldnames(P))", - ), - ) - end - return Tangent{P,typeof(combined)}(combined) -end - -# Tuple tangents are always in their canonical form -canonicalize(tangent::Tangent{<:Tuple,<:Tuple}) = tangent - -# Dict tangents are always in their canonical form. -canonicalize(tangent::Tangent{<:Any,<:AbstractDict}) = tangent - -# Tangents of unspecified primal types (indicated by specifying exactly `Any`) -# all combinations of type-params are specified here to avoid ambiguities -canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent -canonicalize(tangent::Tangent{Any,<:Tuple}) = tangent -canonicalize(tangent::Tangent{Any,<:AbstractDict}) = tangent - """ _zeroed_backing(P) @@ -339,7 +322,7 @@ elementwise_add(a::Dict, b::Dict) = merge(+, a, b) struct PrimalAdditionFailedException{P} <: Exception primal::P - tangent::Tangent{P} + tangent original::Exception end @@ -358,3 +341,160 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} printstyled(io, err.original; color=:yellow) return println(io) end + +####################################### +# immutable Tangent + +function Tangent{P}(; kwargs...) where {P} + backing = (; kwargs...) # construct as NamedTuple + return Tangent{P,typeof(backing)}(backing) +end + +function Tangent{P}(args...) where {P} + return Tangent{P,typeof(args)}(args) +end + +function Tangent{P}() where {P<:Tuple} + backing = () + return Tangent{P,typeof(backing)}(backing) +end + +function Tangent{P}(d::Dict) where {P<:Dict} + return Tangent{P,typeof(d)}(d) +end + +function _backing_error(P, G, E) + msg = "Tangent for the primal $P should be backed by a $E type, not by $G." 
+ return throw(ArgumentError(msg)) +end + +function Base.:(==)(a::Tangent{P,T}, b::Tangent{P,T}) where {P,T} + return backing(a) == backing(b) +end +function Base.:(==)(a::Tangent{P}, b::Tangent{P}) where {P} + all_fields = union(keys(backing(a)), keys(backing(b))) + return all(getproperty(a, f) == getproperty(b, f) for f in all_fields) +end +Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P,Q} = false + +Base.hash(a::Tangent, h::UInt) = Base.hash(backing(canonicalize(a)), h) + +Base.iszero(::Tangent{<:,NamedTuple{}}) = true +Base.iszero(::Tangent{<:,Tuple{}}) = true + + +Base.first(tangent::Tangent{P,T}) where {P,T<:Union{Tuple,NamedTuple}} = first(backing(canonicalize(tangent))) +Base.last(tangent::Tangent{P,T}) where {P,T<:Union{Tuple,NamedTuple}} = last(backing(canonicalize(tangent))) + +Base.tail(t::Tangent{P}) where {P<:Tuple} = Tangent{_tailtype(P)}(Base.tail(backing(canonicalize(t)))...) +@generated _tailtype(::Type{P}) where {P<:Tuple} = Tuple{P.parameters[2:end]...} +Base.tail(t::Tangent{<:Tuple{Any}}) = NoTangent() +Base.tail(t::Tangent{<:Tuple{}}) = NoTangent() + +Base.tail(t::Tangent{P}) where {P<:NamedTuple} = Tangent{_tailtype(P)}(; Base.tail(backing(canonicalize(t)))...) 
+_tailtype(::Type{NamedTuple{S,P}}) where {S,P} = NamedTuple{Base.tail(S), _tailtype(P)} +Base.tail(t::Tangent{<:NamedTuple{<:Any, <:Tuple{Any}}}) = NoTangent() +Base.tail(t::Tangent{<:NamedTuple{<:Any, <:Tuple{}}}) = NoTangent() + +function Base.getindex(tangent::Tangent{P,T}, idx::Int) where {P,T<:Union{Tuple,NamedTuple}} + back = backing(canonicalize(tangent)) + return unthunk(getfield(back, idx)) +end +function Base.getindex(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedTuple} + hasfield(T, idx) || return ZeroTangent() + return unthunk(getfield(backing(tangent), idx)) +end +function Base.getindex(tangent::Tangent, idx) + return unthunk(getindex(backing(tangent), idx)) +end + +function Base.getproperty(tangent::Tangent, idx::Int) + back = backing(canonicalize(tangent)) + return unthunk(getfield(back, idx)) +end +function Base.getproperty(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedTuple} + hasfield(T, idx) || return ZeroTangent() + return unthunk(getfield(backing(tangent), idx)) +end + + +Base.iterate(tangent::Tangent, args...) = iterate(backing(tangent), args...) +Base.length(tangent::Tangent) = length(backing(tangent)) + +Base.eltype(::Type{<:Tangent{<:Any,T}}) where {T} = eltype(T) +function Base.reverse(tangent::Tangent) + rev_backing = reverse(backing(tangent)) + return Tangent{typeof(rev_backing),typeof(rev_backing)}(rev_backing) +end + +function Base.indexed_iterate(tangent::Tangent{P,<:Tuple}, i::Int, state=1) where {P} + return Base.indexed_iterate(backing(tangent), i, state) +end + +function Base.map(f, tangent::Tangent{P,<:Tuple}) where {P} + vals::Tuple = map(f, backing(tangent)) + return Tangent{P,typeof(vals)}(vals) +end +function Base.map(f, tangent::Tangent{P,<:Dict}) where {P<:Dict} + return Tangent{P}(Dict(k => f(v) for (k, v) in backing(tangent))) +end + +Base.conj(tangent::Tangent) = map(conj, tangent) + + + +""" + canonicalize(tangent::Tangent{P}) -> Tangent{P} + +Return the canonical `Tangent` for the primal type `P`. 
+The property names of the returned `Tangent` match the field names of the primal, +and all fields of `P` not present in the input `tangent` are explictly set to `ZeroTangent()`. +""" +function canonicalize(tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} + nil = _zeroed_backing(P) + combined = merge(nil, backing(tangent)) + if length(combined) !== fieldcount(P) + throw( + ArgumentError( + "Tangent fields do not match primal fields.\n" * + "Tangent fields: $L. Primal ($P) fields: $(fieldnames(P))", + ), + ) + end + return Tangent{P,typeof(combined)}(combined) +end + +# Tuple tangents are always in their canonical form +canonicalize(tangent::Tangent{<:Tuple,<:Tuple}) = tangent + +# Dict tangents are always in their canonical form. +canonicalize(tangent::Tangent{<:Any,<:AbstractDict}) = tangent + +# Tangents of unspecified primal types (indicated by specifying exactly `Any`) +# all combinations of type-params are specified here to avoid ambiguities +canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent +canonicalize(tangent::Tangent{Any,<:Tuple}) = tangent +canonicalize(tangent::Tangent{Any,<:AbstractDict}) = tangent + +################################################### +# MutableTangent + +MutableTangent{P}(;kwargs...) 
where P = MutableTangent{P}(NamedTuple(kwargs)) + +ref_backing(t::MutableTangent) = getfield(t, :backing) + +Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(ref_backing(tangent), idx)[] +Base.getproperty(tangent::MutableTangent, idx::Int) = getfield(ref_backing(tangent), idx)[] # break ambig + +function Base.setproperty!(tangent::MutableTangent, name::Symbol, x) + return getfield(ref_backing(tangent), name)[] = x +end +function Base.setproperty!(tangent::MutableTangent, idx::Int, x) + return getfield(ref_backing(tangent), idx)[] = x +end # break ambig + +Base.hash(tangent::MutableTangent, h::UInt64) = hash(backing(tangent), h) +function Base.:(==)(t1::MutableTangent{T1}, t2::MutableTangent{T2}) where {T1, T2} + typeintersect(T1, T2) == Union{} && return false + backing(t1)==backing(t2) +end diff --git a/test/runtests.jl b/test/runtests.jl index 6a4684d03..a3b0971a5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,7 +11,7 @@ using Test @testset "differentials" begin include("tangent_types/abstract_zero.jl") include("tangent_types/thunks.jl") - include("tangent_types/tangent.jl") + include("tangent_types/structural_tangent.jl") include("tangent_types/notimplemented.jl") end diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 028d942ea..741d497b0 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -160,3 +160,156 @@ @test isempty(detect_ambiguities(M)) end end + +@testset "zero_tangent" begin + @testset "basics" begin + @test zero_tangent(1) === 0 + @test zero_tangent(1.0) === 0.0 + mutable struct MutDemo + x::Float64 + end + struct Demo + x::Float64 + end + @test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo} + @test iszero(zero_tangent(MutDemo(1.5))) + + @test zero_tangent((; a=1.3)) isa Tangent{typeof((; a = 1.3))} + @test zero_tangent(Demo(1.2)) isa Tangent{Demo} + @test zero_tangent(Demo(1.2)).x === 0.0 + + @test zero_tangent([1.0, 2.0]) == 
[0.0, 0.0] + @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] + + @test zero_tangent((1.0, 2.0)) == Tangent{Tuple{Float64,Float64}}(0.0, 0.0) + + # Higher order + # StructuralTangents are valid tangents for themselves (just like Numbers) + # and indeed we prefer that, otherwise higher order structural tangents are kinda + # nightmarishly complex types. + @test zero_tangent(zero_tangent(Demo(1.5))) == zero_tangent(Demo(1.5)) + @test zero_tangent(zero_tangent((1.5, 2.5))) == Tangent{Tuple{Float64, Float64}}(0.0, 0.0) + @test zero_tangent(zero_tangent(MutDemo(1.5))) == zero_tangent(MutDemo(1.5)) + end + + @testset "Weird types" begin + @test iszero(zero_tangent(typeof(Int))) # primative type + @test iszero(zero_tangent(typeof(Base.RefValue))) # struct + @test iszero(zero_tangent(Vector)) # UnionAll + @test iszero(zero_tangent(Union{Int, Float64})) # Union + @test iszero(zero_tangent(:abc)) + @test iszero(zero_tangent("abc")) + @test iszero(zero_tangent(sin)) + end + + @testset "undef elements Vector" begin + x = Vector{Vector{Float64}}(undef, 3) + x[2] = [1.0, 2.0] + dx = zero_tangent(x) + @test dx isa Vector{Vector{Float64}} + @test length(dx) == 3 + @test !isassigned(dx, 1) # We may reconsider this later + @test dx[2] == [0.0, 0.0] + @test !isassigned(dx, 3) # We may reconsider this later + + a = Vector{MutDemo}(undef, 3) + a[2] = MutDemo(1.5) + da = zero_tangent(a) + @test !isassigned(da, 1) # We may reconsider this later + @test iszero(da[2]) + @test !isassigned(da, 3) # We may reconsider this later + + db = zero_tangent(Vector{MutDemo}(undef, 3)) + @test all(ii -> !isassigned(db, ii), eachindex(db)) # We may reconsider this later + @test length(db) == 3 + @test db isa Vector + end + + @testset "undef fields struct" begin + dx = zero_tangent(Core.Box()) + @test dx.contents isa ZeroTangent + @test (dx.contents = 2.0) == 2.0 # should be assignable + + mutable struct MyPartiallyDefinedStruct + intro::Float64 + contents::Number + 
MyPartiallyDefinedStruct(x) = new(x) + end + dy = zero_tangent(MyPartiallyDefinedStruct(1.5)) + @test iszero(dy.intro) + @test iszero(dy.contents) + @test (dy.contents = 2.0) == 2.0 # should be assignable + + mutable struct MyPartiallyDefinedStructWithAnys + intro::Float64 + contents::Any + MyPartiallyDefinedStructWithAnys(x) = new(x) + end + dy = zero_tangent(MyPartiallyDefinedStructWithAnys(1.5)) + @test iszero(dy.intro) + @test iszero(dy.contents) + @test dy.contents === ZeroTangent() # we just don't know anything about this data + @test (dy.contents = 2.0) == 2.0 # should be assignable + @test (dy.contents = [2.0, 4.0]) == [2.0, 4.0] # should be assignable to different values + + mutable struct MyStructWithNonConcreteFields + x::Any + y::Union{Float64,Vector{Float64}} + z::AbstractVector + end + d = zero_tangent(MyStructWithNonConcreteFields(1.0, 2.0, [3.0])) + @test iszero(d.x) + d.x = Tangent{Base.RefValue{Float64}}(; x=1.5) + @test d.x == Tangent{Base.RefValue{Float64}}(; x=1.5) #should be assignable + d.x = 2.4 + @test d.x == 2.4 #should be assignable + @test iszero(d.y) + d.y = 2.4 + @test d.y == 2.4 #should be assignable + d.y = [2.4] + @test d.y == [2.4] #should be assignable + @test iszero(d.z) + d.z = [1.0, 2.0] + @test d.z == [1.0, 2.0] + d.z = @view [2.0, 3.0, 4.0][1:2] + @test d.z == [2.0, 3.0] + @test d.z isa SubArray + end + + @testset "cyclic references" begin + mutable struct Link + data::Float64 + next::Link + Link(data) = new(data) + end + + lk = Link(1.5) + lk.next = lk + + d = zero_tangent(lk) + @test d.data == 0.0 + @test d.next === d + + # The following two cases are broken + # We hope they are not too significant, because in general if you AD step by step they should work + # (as should the one above so maybe we should get rid of this extra complex logic) + # It's only a problem if you first do the multistep build then `zero_tangent` rather than `zero_tangent` at the constructor. 
+ + # Idea: check if `!isbitstype` only if so do we need to worry about caching etc + struct CarryingArray + x::Vector + end + ca = CarryingArray(Any[1.5]) + push!(ca.x, ca) + @test_broken d_ca = zero_tangent(ca) + @test_broken d_ca[1] == 0.0 + @test_broken d_ca[2] === _ca + + # Idea: check if typeof(xs) <: eltype(xs), if so need to cache it before computing + xs = Any[1.5] + push!(xs, xs) + @test_broken d_xs = zero_tangent(xs) + @test_broken d_xs[1] == 0.0 + @test_broken d_xs[2] == d_xs + end +end diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl new file mode 100644 index 000000000..c177b05f4 --- /dev/null +++ b/test/tangent_types/structural_tangent.jl @@ -0,0 +1,502 @@ +# For testing Tangent +struct Foo + x + y::Float64 +end + +mutable struct MFoo + x::Float64 + y +end + +# For testing Primal + Tangent performance +struct Bar + x::Float64 +end + +# For testing Tangent: it is an invarient of the type that x2 = 2x +# so simple addition can not be defined +struct StructWithInvariant + x + x2 + + StructWithInvariant(x) = new(x, 2x) +end +@testset "StructuralTangent" begin + @testset "Tangent" begin + @testset "empty types" begin + @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{},Tuple{}} + end + + @testset "constructor" begin + t = (1.0, 2.0) + nt = (x=1, y=2.0) + d = Dict(:x => 1.0, :y => 2.0) + vals = [1, 2] + + @test_throws ArgumentError Tangent{typeof(t),typeof(nt)}(nt) + @test_throws ArgumentError Tangent{typeof(t),typeof(d)}(d) + + @test_throws ArgumentError Tangent{typeof(d),typeof(nt)}(nt) + @test_throws ArgumentError Tangent{typeof(d),typeof(t)}(t) + + @test_throws ArgumentError Tangent{typeof(nt),typeof(vals)}(vals) + @test_throws ArgumentError Tangent{typeof(nt),typeof(d)}(d) + @test_throws ArgumentError Tangent{typeof(nt),typeof(t)}(t) + + @test_throws ArgumentError Tangent{Foo,typeof(d)}(d) + @test_throws ArgumentError Tangent{Foo,typeof(t)}(t) + end + + @testset "==" begin + @test Tangent{Foo}(; x=0.1, 
y=2.5) == Tangent{Foo}(; x=0.1, y=2.5) + @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; y=2.5, x=0.1) + @test Tangent{Foo}(; y=2.5, x=ZeroTangent()) == Tangent{Foo}(; y=2.5) + + @test Tangent{Tuple{Float64}}(2.0) == Tangent{Tuple{Float64}}(2.0) + @test Tangent{Dict}(Dict(4 => 3)) == Tangent{Dict}(Dict(4 => 3)) + + tup = (1.0, 2.0) + @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2 * 1.0)) + @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, 2) + + @test Tangent{Foo}(; y=2.0) == Tangent{Foo}(; x=ZeroTangent(), y=Float32(2.0)) + end + + @testset "hash" begin + @test hash(Tangent{Foo}(; x=0.1, y=2.5)) == hash(Tangent{Foo}(; y=2.5, x=0.1)) + @test hash(Tangent{Foo}(; y=2.5, x=ZeroTangent())) == hash(Tangent{Foo}(; y=2.5)) + end + + @testset "indexing, iterating, and properties" begin + @test keys(Tangent{Foo}(; x=2.5)) == (:x,) + @test propertynames(Tangent{Foo}(; x=2.5)) == (:x,) + @test haskey(Tangent{Foo}(; x=2.5), :x) == true + if isdefined(Base, :hasproperty) + @test hasproperty(Tangent{Foo}(; x=2.5), :y) == false + end + @test Tangent{Foo}(; x=2.5).x == 2.5 + + tang1 = Tangent{Tuple{Float64}}(2.0) + @test keys(tang1) == Base.OneTo(1) + @test propertynames(Tangent{Tuple{Float64}}(2.0)) == (1,) + @test getindex(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 + @test getindex(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 + @test getproperty(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 + @test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 + @test NoTangent() === @inferred Base.tail(tang1) + @test NoTangent() === @inferred Base.tail(Tangent{Tuple{}}()) + + tang3 = Tangent{Tuple{Float64, String, Vector{Float64}}}(1.0, NoTangent(), @thunk [3.0] .+ 4) + @test @inferred(first(tang3)) === tang3[1] === 1.0 + @test @inferred(last(tang3)) isa Thunk + @test unthunk(last(tang3)) == [7.0] + @test Tuple(@inferred Base.tail(tang3))[1] === NoTangent() + @test Tuple(Base.tail(tang3))[end] isa Thunk + + NT = NamedTuple{(:a, 
:b),Tuple{Float64,Float64}} + @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 + @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() + @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() + @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 + + @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 + @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() + @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() + @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 + + @test first(Tangent{NT}(; a=(@thunk 2.0^2))) isa Thunk + @test unthunk(first(Tangent{NT}(; a=(@thunk 2.0^2)))) == 4.0 + @test last(Tangent{NT}(; a=(@thunk 2.0^2))) isa ZeroTangent + + ntang1 = @inferred Base.tail(Tangent{NT}(; b=(@thunk 2.0^2))) + @test ntang1 isa Tangent{<:NamedTuple{(:b,)}} + @test NoTangent() === @inferred Base.tail(ntang1) + + # TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516 + # if VERSION >= v"1.8-" + # @test haskey(Tangent{Tuple{Float64}}(2.0), 1) == true + # else + # @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true + # end + @test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false + + @test length(Tangent{Foo}(; x=2.5)) == 1 + @test length(Tangent{Tuple{Float64}}(2.0)) == 1 + + @test eltype(Tangent{Foo}(; x=2.5)) == Float64 + @test eltype(Tangent{Tuple{Float64}}(2.0)) == Float64 + + # Testing iterate via collect + @test collect(Tangent{Foo}(; x=2.5)) == [2.5] + @test collect(Tangent{Tuple{Float64}}(2.0)) == [2.0] + + # Test indexed_iterate + ctup = Tangent{Tuple{Float64,Int64}}(2.0, 3) + _unpack2tuple = function (tangent) + a, b = tangent + return (a, b) + end + @inferred _unpack2tuple(ctup) + @test _unpack2tuple(ctup) === (2.0, 3) + + # Test getproperty is inferrable + _unpacknamedtuple = tangent -> (tangent.x, tangent.y) + if VERSION ≥ v"1.2" + @inferred _unpacknamedtuple(Tangent{Foo}(; x=2, y=3.0)) + @inferred 
_unpacknamedtuple(Tangent{Foo}(; y=3.0)) + end + end + + @testset "reverse" begin + c = Tangent{Tuple{Int,Int,String}}(1, 2, "something") + cr = Tangent{Tuple{String,Int,Int}}("something", 2, 1) + @test reverse(c) === cr + + if VERSION < v"1.9-" + # can't reverse a named tuple or a dict + @test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0)) + + d = Dict(:x => 1, :y => 2.0) + cdict = Tangent{typeof(d),typeof(d)}(d) + @test_throws MethodError reverse(Tangent{Foo}()) + else + # These now work but do we care? + end + end + + @testset "unset properties" begin + @test Tangent{Foo}(; x=1.4).y === ZeroTangent() + end + + @testset "conj" begin + @test conj(Tangent{Foo}(; x=2.0 + 3.0im)) == Tangent{Foo}(; x=2.0 - 3.0im) + @test ==( + conj(Tangent{Tuple{Float64}}(2.0 + 3.0im)), Tangent{Tuple{Float64}}(2.0 - 3.0im) + ) + @test ==( + conj(Tangent{Dict}(Dict(4 => 2.0 + 3.0im))), + Tangent{Dict}(Dict(4 => 2.0 + -3.0im)), + ) + end + + @testset "canonicalize" begin + # Testing iterate via collect + @test ==(canonicalize(Tangent{Tuple{Float64}}(2.0)), Tangent{Tuple{Float64}}(2.0)) + + @test ==(canonicalize(Tangent{Dict}(Dict(4 => 3))), Tangent{Dict}(Dict(4 => 3))) + + # For structure it needs to match order and ZeroTangent() fill to match primal + CFoo = Tangent{Foo} + @test canonicalize(CFoo(; x=2.5, y=10)) == CFoo(; x=2.5, y=10) + @test canonicalize(CFoo(; y=10, x=2.5)) == CFoo(; x=2.5, y=10) + @test canonicalize(CFoo(; y=10)) == CFoo(; x=ZeroTangent(), y=10) + + @test_throws ArgumentError canonicalize(CFoo(; q=99.0, x=2.5)) + + @testset "unspecified primal type" begin + c1 = Tangent{Any}(; a=1, b=2) + c2 = Tangent{Any}(1, 2) + c3 = Tangent{Any}(Dict(4 => 3)) + + @test c1 == canonicalize(c1) + @test c2 == canonicalize(c2) + @test c3 == canonicalize(c3) + end + end + + @testset "+ with other composites" begin + @testset "Structs" begin + CFoo = Tangent{Foo} + @test CFoo(; x=1.5) + CFoo(; x=2.5) == CFoo(; x=4.0) + @test CFoo(; y=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, 
x=2.5) + @test CFoo(; y=1.5, x=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=4.0) + end + + @testset "Tuples" begin + @test ==( + typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), Tangent{Tuple{},Tuple{}} + ) + @test ( + Tangent{Tuple{Float64,Float64}}(1.0, 2.0) + + Tangent{Tuple{Float64,Float64}}(1.0, 1.0) + ) == Tangent{Tuple{Float64,Float64}}(2.0, 3.0) + end + + @testset "NamedTuples" begin + make_tangent(nt::NamedTuple) = Tangent{typeof(nt)}(; nt...) + t1 = make_tangent((; a=1.5, b=0.0)) + t2 = make_tangent((; a=0.0, b=2.5)) + t_sum = make_tangent((a=1.5, b=2.5)) + @test t1 + t2 == t_sum + end + + @testset "Dicts" begin + d1 = Tangent{Dict}(Dict(4 => 3.0, 3 => 2.0)) + d2 = Tangent{Dict}(Dict(4 => 3.0, 2 => 2.0)) + d_sum = Tangent{Dict}(Dict(4 => 3.0 + 3.0, 3 => 2.0, 2 => 2.0)) + @test d1 + d2 == d_sum + end + + @testset "Fields of type NotImplemented" begin + CFoo = Tangent{Foo} + a = CFoo(; x=1.5) + b = CFoo(; x=@not_implemented("")) + for (x, y) in ((a, b), (b, a), (b, b)) + z = x + y + @test z isa CFoo + @test z.x isa ChainRulesCore.NotImplemented + end + + a = Tangent{Tuple}(1.5) + b = Tangent{Tuple}(@not_implemented("")) + for (x, y) in ((a, b), (b, a), (b, b)) + z = x + y + @test z isa Tangent{Tuple} + @test first(z) isa ChainRulesCore.NotImplemented + end + + a = Tangent{NamedTuple{(:x,)}}(; x=1.5) + b = Tangent{NamedTuple{(:x,)}}(; x=@not_implemented("")) + for (x, y) in ((a, b), (b, a), (b, b)) + z = x + y + @test z isa Tangent{NamedTuple{(:x,)}} + @test z.x isa ChainRulesCore.NotImplemented + end + + a = Tangent{Dict}(Dict(:x => 1.5)) + b = Tangent{Dict}(Dict(:x => @not_implemented(""))) + for (x, y) in ((a, b), (b, a), (b, b)) + z = x + y + @test z isa Tangent{Dict} + @test z[:x] isa ChainRulesCore.NotImplemented + end + end + end + + @testset "+ with Primals" begin + @testset "Structs" begin + @test Foo(3.5, 1.5) + Tangent{Foo}(; x=2.5) == Foo(6.0, 1.5) + @test Tangent{Foo}(; x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) + @test (@ballocated Bar(0.5) + 
Tangent{Bar}(; x=0.5)) == 0 + end + + @testset "Tuples" begin + @test Tangent{Tuple{}}() + () == () + @test ((1.0, 2.0) + Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) == (2.0, 3.0) + @test (Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) + end + + @testset "NamedTuple" begin + ntx = (; a=1.5) + @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a=3.0) + + nty = (; a=1.5, b=0.5) + @test Tangent{typeof(nty)}(; nty...) + nty == (; a=3.0, b=1.0) + end + + @testset "Dicts" begin + d_primal = Dict(4 => 3.0, 3 => 2.0) + d_tangent = Tangent{typeof(d_primal)}(Dict(4 => 5.0)) + @test d_primal + d_tangent == Dict(4 => 3.0 + 5.0, 3 => 2.0) + end + end + + @testset "+ with Primals, with inner constructor" begin + value = StructWithInvariant(10.0) + diff = Tangent{StructWithInvariant}(; x=2.0, x2=6.0) + + @testset "with and without debug mode" begin + @assert ChainRulesCore.debug_mode() == false + @test_throws MethodError (value + diff) + @test_throws MethodError (diff + value) + + ChainRulesCore.debug_mode() = true # enable debug mode + @test_throws ChainRulesCore.PrimalAdditionFailedException (value + diff) + @test_throws ChainRulesCore.PrimalAdditionFailedException (diff + value) + ChainRulesCore.debug_mode() = false # disable it again + end + + # Now we define constuction for ChainRulesCore.jl's purposes: + # It is going to determine the root quanity of the invarient + function ChainRulesCore.construct(::Type{StructWithInvariant}, nt::NamedTuple) + x = (nt.x + nt.x2 / 2) / 2 + return StructWithInvariant(x) + end + @test value + diff == StructWithInvariant(12.5) + @test diff + value == StructWithInvariant(12.5) + end + + @testset "differential arithmetic" begin + c = Tangent{Foo}(; y=1.5, x=2.5) + + @test NoTangent() * c == NoTangent() + @test c * NoTangent() == NoTangent() + @test dot(NoTangent(), c) == NoTangent() + @test dot(c, NoTangent()) == NoTangent() + @test norm(Tangent{Foo}(; y=c.y, x=NoTangent())) == c.y + @test norm(NoTangent(), Inf) == 0 + + 
@test ZeroTangent() * c == ZeroTangent() + @test c * ZeroTangent() == ZeroTangent() + @test dot(ZeroTangent(), c) == ZeroTangent() + @test dot(c, ZeroTangent()) == ZeroTangent() + @test norm(ZeroTangent()) == 0 + @test norm(ZeroTangent(), 0.4) == 0 + + @test true * c === c + @test c * true === c + + t = @thunk 2 + @test t * c == 2 * c + @test c * t == c * 2 + end + + @testset "-Tangent" begin + t = Tangent{Foo}(; x=1.0, y=-2.0) + @test -t == Tangent{Foo}(; x=-1.0, y=2.0) + @test -1.0 * t == -t + end + + @testset "scaling" begin + @test ( + 2 * Tangent{Foo}(; y=1.5, x=2.5) == + Tangent{Foo}(; y=3.0, x=5.0) == + Tangent{Foo}(; y=1.5, x=2.5) * 2 + ) + @test ( + 2 * Tangent{Tuple{Float64,Float64}}(2.0, 4.0) == + Tangent{Tuple{Float64,Float64}}(4.0, 8.0) == + Tangent{Tuple{Float64,Float64}}(2.0, 4.0) * 2 + ) + d = Tangent{Dict}(Dict(4 => 3.0)) + two_d = Tangent{Dict}(Dict(4 => 2 * 3.0)) + @test 2 * d == two_d == d * 2 + + @test_throws MethodError [1, 2] * Tangent{Foo}(; y=1.5, x=2.5) + @test_throws MethodError [1, 2] * d + @test_throws MethodError Tangent{Foo}(; y=1.5, x=2.5) * @thunk [1 2; 3 4] + end + + @testset "iszero" begin + @test iszero(Tangent{Foo}()) + @test iszero(Tangent{Tuple{}}()) + @test iszero(Tangent{Foo}(; x=ZeroTangent())) + @test iszero(Tangent{Foo}(; y=0.0)) + @test iszero(Tangent{Foo}(; x=Tangent{Tuple{}}(), y=0.0)) + + @test !iszero(Tangent{Foo}(; y=3.0)) + end + + @testset "show" begin + @test repr(Tangent{Foo}(; x=1)) == "Tangent{Foo}(x = 1,)" + # check for exact regex match not occurence( `^...$`) + # and allowing optional whitespace (`\s?`) + @test occursin( + r"^Tangent{Tuple{Int64,\s?Int64}}\(1,\s?2\)$", + repr(Tangent{Tuple{Int64,Int64}}(1, 2)), + ) + + @test repr(Tangent{Foo}()) == "Tangent{Foo}()" + + @test ==( + repr(MutableTangent{MFoo}((; x=1.5, y=[1.0, 2.0]))), + "MutableTangent{MFoo}(x = 1.5, y = [1.0, 2.0])", + ) + end + + @testset "internals" begin + @testset "Can't do backing on primative type" begin + @test_throws Exception 
ChainRulesCore.backing(1.4) + end + + @testset "Internals don't allocate a ton" begin + bk = (; x=1.0, y=2.0) + VERSION >= v"1.5" && + @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 + + # weaker version of the above (which should pass on all versions) + @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 48 + @test (@ballocated ChainRulesCore.elementwise_add($bk, $bk)) <= 48 + end + end + + @testset "non-same-typed differential arithmetic" begin + nt = (; a=1, b=2.0) + c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1) + @test nt + c == (; a=1, b=2.1) + end + + @testset "printing" begin + t5 = Tuple(rand(3)) + nt3 = (x=t5, y=t5, z=nothing) + tang = ProjectTo(nt3)(nt3) # moderately complicated Tangent + @test contains(sprint(show, tang), "...}(x = Tangent") # gets shortened + @test contains(sprint(show, tang), sprint(show, tang.x)) # inner piece appears whole + end + end + + @testset "MutableTangent" begin + mutable struct MDemo + x::Float64 + end + function ChainRulesCore.frule( + (_, ȯbj, _, ẋ), ::typeof(setfield!), obj::MDemo, field, x + ) + y = setfield!(obj, field, x) + ẏ = setproperty!(ȯbj, field, ẋ) + return y, ẏ + end + + @testset "usecase" begin + obj = MDemo(99.0) + ∂obj = MutableTangent{MDemo}(; x=1.5) + frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0) + @test ∂obj.x == 10.0 + @test obj.x == 95.0 + + frule((NoTangent(), ∂obj, NoTangent(), 20.0), setfield!, obj, 1, 96.0) + @test ∂obj.x == 20.0 + @test getproperty(∂obj, 1) == 20.0 + @test obj.x == 96.0 + end + + @testset "== and hash" begin + @test MutableTangent{MDemo}(; x=1.0f0) == MutableTangent{MDemo}(; x=1.0) + @test MutableTangent{MDemo}(; x=1.0) == MutableTangent{MDemo}(; x=1.0f0) + @test MutableTangent{MDemo}(; x=2.0) != MutableTangent{MDemo}(; x=1.0) + @test MutableTangent{MDemo}(; x=1.0) != MutableTangent{MDemo}(; x=2.0) + + nt = (; x=1.0) + @test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(; x=1.0) + + @test hash(MutableTangent{MDemo}(; 
x=1.0f0)) == hash(MutableTangent{MDemo}(; x=1.0)) + end + + @testset "Mutation" begin + v = MutableTangent{MFoo}(; x=1.5, y=2.4) + v.x = 1.6 + @test v == MutableTangent{MFoo}(; x=1.6, y=2.4) + v.y = [1.0, 2.0] # change type, because primal can change type + @test v == MutableTangent{MFoo}(; x=1.6, y=[1.0, 2.0]) + end + end + + @testset "map" begin + @testset "Tangent" begin + ∂foo = Tangent{Foo}(; x=1.5, y=2.4) + @test map(v -> 2 * v, ∂foo) == Tangent{Foo}(; x=3.0, y=4.8) + + ∂foo = Tangent{Foo}(; x=1.5) + @test map(v -> 2 * v, ∂foo) == Tangent{Foo}(; x=3.0) + end + @testset "MutableTangent" begin + ∂foo = MutableTangent{MFoo}(; x=1.5, y=2.4) + ∂foo2 = map(v -> 2 * v, ∂foo) + @test ∂foo2 == MutableTangent{MFoo}(; x=3.0, y=4.8) + # Check can still be mutated to new typ + ∂foo2.y = [1.0, 2.0] + @test ∂foo2 == MutableTangent{MFoo}(; x=3.0, y=[1.0, 2.0]) + end + end +end \ No newline at end of file diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/tangent.jl deleted file mode 100644 index b0cb5577e..000000000 --- a/test/tangent_types/tangent.jl +++ /dev/null @@ -1,427 +0,0 @@ -# For testing Tangent -struct Foo - x - y::Float64 -end - -# For testing Primal + Tangent performance -struct Bar - x::Float64 -end - -# For testing Tangent: it is an invarient of the type that x2 = 2x -# so simple addition can not be defined -struct StructWithInvariant - x - x2 - - StructWithInvariant(x) = new(x, 2x) -end - -@testset "Tangent" begin - @testset "empty types" begin - @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{},Tuple{}} - end - - @testset "constructor" begin - t = (1.0, 2.0) - nt = (x=1, y=2.0) - d = Dict(:x => 1.0, :y => 2.0) - vals = [1, 2] - - @test_throws ArgumentError Tangent{typeof(t),typeof(nt)}(nt) - @test_throws ArgumentError Tangent{typeof(t),typeof(d)}(d) - - @test_throws ArgumentError Tangent{typeof(d),typeof(nt)}(nt) - @test_throws ArgumentError Tangent{typeof(d),typeof(t)}(t) - - @test_throws ArgumentError Tangent{typeof(nt),typeof(vals)}(vals) - 
@test_throws ArgumentError Tangent{typeof(nt),typeof(d)}(d) - @test_throws ArgumentError Tangent{typeof(nt),typeof(t)}(t) - - @test_throws ArgumentError Tangent{Foo,typeof(d)}(d) - @test_throws ArgumentError Tangent{Foo,typeof(t)}(t) - end - - @testset "==" begin - @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; x=0.1, y=2.5) - @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; y=2.5, x=0.1) - @test Tangent{Foo}(; y=2.5, x=ZeroTangent()) == Tangent{Foo}(; y=2.5) - - @test Tangent{Tuple{Float64}}(2.0) == Tangent{Tuple{Float64}}(2.0) - @test Tangent{Dict}(Dict(4 => 3)) == Tangent{Dict}(Dict(4 => 3)) - - tup = (1.0, 2.0) - @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2 * 1.0)) - @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, 2) - - @test Tangent{Foo}(; y=2.0) == Tangent{Foo}(; x=ZeroTangent(), y=Float32(2.0)) - end - - @testset "hash" begin - @test hash(Tangent{Foo}(; x=0.1, y=2.5)) == hash(Tangent{Foo}(; y=2.5, x=0.1)) - @test hash(Tangent{Foo}(; y=2.5, x=ZeroTangent())) == hash(Tangent{Foo}(; y=2.5)) - end - - @testset "indexing, iterating, and properties" begin - @test keys(Tangent{Foo}(; x=2.5)) == (:x,) - @test propertynames(Tangent{Foo}(; x=2.5)) == (:x,) - @test haskey(Tangent{Foo}(; x=2.5), :x) == true - if isdefined(Base, :hasproperty) - @test hasproperty(Tangent{Foo}(; x=2.5), :y) == false - end - @test Tangent{Foo}(; x=2.5).x == 2.5 - - tang1 = Tangent{Tuple{Float64}}(2.0) - @test keys(tang1) == Base.OneTo(1) - @test propertynames(Tangent{Tuple{Float64}}(2.0)) == (1,) - @test getindex(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 - @test getindex(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 - @test getproperty(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 - @test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 - @test NoTangent() === @inferred Base.tail(tang1) - @test NoTangent() === @inferred Base.tail(Tangent{Tuple{}}()) - - tang3 = Tangent{Tuple{Float64, String, Vector{Float64}}}(1.0, 
NoTangent(), @thunk [3.0] .+ 4) - @test @inferred(first(tang3)) === tang3[1] === 1.0 - @test @inferred(last(tang3)) isa Thunk - @test unthunk(last(tang3)) == [7.0] - @test Tuple(@inferred Base.tail(tang3))[1] === NoTangent() - @test Tuple(Base.tail(tang3))[end] isa Thunk - - NT = NamedTuple{(:a, :b),Tuple{Float64,Float64}} - @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 - @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() - @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() - @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 - - @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 - @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() - @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() - @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 - - @test first(Tangent{NT}(; a=(@thunk 2.0^2))) isa Thunk - @test unthunk(first(Tangent{NT}(; a=(@thunk 2.0^2)))) == 4.0 - @test last(Tangent{NT}(; a=(@thunk 2.0^2))) isa ZeroTangent - - ntang1 = @inferred Base.tail(Tangent{NT}(; b=(@thunk 2.0^2))) - @test ntang1 isa Tangent{<:NamedTuple{(:b,)}} - @test NoTangent() === @inferred Base.tail(ntang1) - - # TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516 - # if VERSION >= v"1.8-" - # @test haskey(Tangent{Tuple{Float64}}(2.0), 1) == true - # else - # @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true - # end - @test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false - - @test length(Tangent{Foo}(; x=2.5)) == 1 - @test length(Tangent{Tuple{Float64}}(2.0)) == 1 - - @test eltype(Tangent{Foo}(; x=2.5)) == Float64 - @test eltype(Tangent{Tuple{Float64}}(2.0)) == Float64 - - # Testing iterate via collect - @test collect(Tangent{Foo}(; x=2.5)) == [2.5] - @test collect(Tangent{Tuple{Float64}}(2.0)) == [2.0] - - # Test indexed_iterate - ctup = Tangent{Tuple{Float64,Int64}}(2.0, 3) - _unpack2tuple = function (tangent) - a, b = tangent - 
return (a, b) - end - @inferred _unpack2tuple(ctup) - @test _unpack2tuple(ctup) === (2.0, 3) - - # Test getproperty is inferrable - _unpacknamedtuple = tangent -> (tangent.x, tangent.y) - if VERSION ≥ v"1.2" - @inferred _unpacknamedtuple(Tangent{Foo}(; x=2, y=3.0)) - @inferred _unpacknamedtuple(Tangent{Foo}(; y=3.0)) - end - end - - @testset "reverse" begin - c = Tangent{Tuple{Int,Int,String}}(1, 2, "something") - cr = Tangent{Tuple{String,Int,Int}}("something", 2, 1) - @test reverse(c) === cr - - if VERSION < v"1.9-" - # can't reverse a named tuple or a dict - @test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0)) - - d = Dict(:x => 1, :y => 2.0) - cdict = Tangent{typeof(d),typeof(d)}(d) - @test_throws MethodError reverse(Tangent{Foo}()) - else - # These now work but do we care? - end - end - - @testset "unset properties" begin - @test Tangent{Foo}(; x=1.4).y === ZeroTangent() - end - - @testset "conj" begin - @test conj(Tangent{Foo}(; x=2.0 + 3.0im)) == Tangent{Foo}(; x=2.0 - 3.0im) - @test ==( - conj(Tangent{Tuple{Float64}}(2.0 + 3.0im)), Tangent{Tuple{Float64}}(2.0 - 3.0im) - ) - @test ==( - conj(Tangent{Dict}(Dict(4 => 2.0 + 3.0im))), - Tangent{Dict}(Dict(4 => 2.0 + -3.0im)), - ) - end - - @testset "canonicalize" begin - # Testing iterate via collect - @test ==(canonicalize(Tangent{Tuple{Float64}}(2.0)), Tangent{Tuple{Float64}}(2.0)) - - @test ==(canonicalize(Tangent{Dict}(Dict(4 => 3))), Tangent{Dict}(Dict(4 => 3))) - - # For structure it needs to match order and ZeroTangent() fill to match primal - CFoo = Tangent{Foo} - @test canonicalize(CFoo(; x=2.5, y=10)) == CFoo(; x=2.5, y=10) - @test canonicalize(CFoo(; y=10, x=2.5)) == CFoo(; x=2.5, y=10) - @test canonicalize(CFoo(; y=10)) == CFoo(; x=ZeroTangent(), y=10) - - @test_throws ArgumentError canonicalize(CFoo(; q=99.0, x=2.5)) - - @testset "unspecified primal type" begin - c1 = Tangent{Any}(; a=1, b=2) - c2 = Tangent{Any}(1, 2) - c3 = Tangent{Any}(Dict(4 => 3)) - - @test c1 == canonicalize(c1) - 
@test c2 == canonicalize(c2) - @test c3 == canonicalize(c3) - end - end - - @testset "+ with other composites" begin - @testset "Structs" begin - CFoo = Tangent{Foo} - @test CFoo(; x=1.5) + CFoo(; x=2.5) == CFoo(; x=4.0) - @test CFoo(; y=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=2.5) - @test CFoo(; y=1.5, x=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=4.0) - end - - @testset "Tuples" begin - @test ==( - typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), Tangent{Tuple{},Tuple{}} - ) - @test ( - Tangent{Tuple{Float64,Float64}}(1.0, 2.0) + - Tangent{Tuple{Float64,Float64}}(1.0, 1.0) - ) == Tangent{Tuple{Float64,Float64}}(2.0, 3.0) - end - - @testset "NamedTuples" begin - make_tangent(nt::NamedTuple) = Tangent{typeof(nt)}(; nt...) - t1 = make_tangent((; a=1.5, b=0.0)) - t2 = make_tangent((; a=0.0, b=2.5)) - t_sum = make_tangent((a=1.5, b=2.5)) - @test t1 + t2 == t_sum - end - - @testset "Dicts" begin - d1 = Tangent{Dict}(Dict(4 => 3.0, 3 => 2.0)) - d2 = Tangent{Dict}(Dict(4 => 3.0, 2 => 2.0)) - d_sum = Tangent{Dict}(Dict(4 => 3.0 + 3.0, 3 => 2.0, 2 => 2.0)) - @test d1 + d2 == d_sum - end - - @testset "Fields of type NotImplemented" begin - CFoo = Tangent{Foo} - a = CFoo(; x=1.5) - b = CFoo(; x=@not_implemented("")) - for (x, y) in ((a, b), (b, a), (b, b)) - z = x + y - @test z isa CFoo - @test z.x isa ChainRulesCore.NotImplemented - end - - a = Tangent{Tuple}(1.5) - b = Tangent{Tuple}(@not_implemented("")) - for (x, y) in ((a, b), (b, a), (b, b)) - z = x + y - @test z isa Tangent{Tuple} - @test first(z) isa ChainRulesCore.NotImplemented - end - - a = Tangent{NamedTuple{(:x,)}}(; x=1.5) - b = Tangent{NamedTuple{(:x,)}}(; x=@not_implemented("")) - for (x, y) in ((a, b), (b, a), (b, b)) - z = x + y - @test z isa Tangent{NamedTuple{(:x,)}} - @test z.x isa ChainRulesCore.NotImplemented - end - - a = Tangent{Dict}(Dict(:x => 1.5)) - b = Tangent{Dict}(Dict(:x => @not_implemented(""))) - for (x, y) in ((a, b), (b, a), (b, b)) - z = x + y - @test z isa Tangent{Dict} - @test z[:x] isa 
ChainRulesCore.NotImplemented - end - end - end - - @testset "+ with Primals" begin - @testset "Structs" begin - @test Foo(3.5, 1.5) + Tangent{Foo}(; x=2.5) == Foo(6.0, 1.5) - @test Tangent{Foo}(; x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) - @test (@ballocated Bar(0.5) + Tangent{Bar}(; x=0.5)) == 0 - end - - @testset "Tuples" begin - @test Tangent{Tuple{}}() + () == () - @test ((1.0, 2.0) + Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) == (2.0, 3.0) - @test (Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) - end - - @testset "NamedTuple" begin - ntx = (; a=1.5) - @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a=3.0) - - nty = (; a=1.5, b=0.5) - @test Tangent{typeof(nty)}(; nty...) + nty == (; a=3.0, b=1.0) - end - - @testset "Dicts" begin - d_primal = Dict(4 => 3.0, 3 => 2.0) - d_tangent = Tangent{typeof(d_primal)}(Dict(4 => 5.0)) - @test d_primal + d_tangent == Dict(4 => 3.0 + 5.0, 3 => 2.0) - end - end - - @testset "+ with Primals, with inner constructor" begin - value = StructWithInvariant(10.0) - diff = Tangent{StructWithInvariant}(; x=2.0, x2=6.0) - - @testset "with and without debug mode" begin - @assert ChainRulesCore.debug_mode() == false - @test_throws MethodError (value + diff) - @test_throws MethodError (diff + value) - - ChainRulesCore.debug_mode() = true # enable debug mode - @test_throws ChainRulesCore.PrimalAdditionFailedException (value + diff) - @test_throws ChainRulesCore.PrimalAdditionFailedException (diff + value) - ChainRulesCore.debug_mode() = false # disable it again - end - - # Now we define constuction for ChainRulesCore.jl's purposes: - # It is going to determine the root quanity of the invarient - function ChainRulesCore.construct(::Type{StructWithInvariant}, nt::NamedTuple) - x = (nt.x + nt.x2 / 2) / 2 - return StructWithInvariant(x) - end - @test value + diff == StructWithInvariant(12.5) - @test diff + value == StructWithInvariant(12.5) - end - - @testset "differential arithmetic" begin - c = Tangent{Foo}(; y=1.5, 
x=2.5) - - @test NoTangent() * c == NoTangent() - @test c * NoTangent() == NoTangent() - @test dot(NoTangent(), c) == NoTangent() - @test dot(c, NoTangent()) == NoTangent() - @test norm(Tangent{Foo}(; y=c.y, x=NoTangent())) == c.y - @test norm(NoTangent(), Inf) == 0 - - @test ZeroTangent() * c == ZeroTangent() - @test c * ZeroTangent() == ZeroTangent() - @test dot(ZeroTangent(), c) == ZeroTangent() - @test dot(c, ZeroTangent()) == ZeroTangent() - @test norm(ZeroTangent()) == 0 - @test norm(ZeroTangent(), 0.4) == 0 - - @test true * c === c - @test c * true === c - - t = @thunk 2 - @test t * c == 2 * c - @test c * t == c * 2 - end - - @testset "-Tangent" begin - t = Tangent{Foo}(; x=1.0, y=-2.0) - @test -t == Tangent{Foo}(; x=-1.0, y=2.0) - @test -1.0 * t == -t - end - - @testset "scaling" begin - @test ( - 2 * Tangent{Foo}(; y=1.5, x=2.5) == - Tangent{Foo}(; y=3.0, x=5.0) == - Tangent{Foo}(; y=1.5, x=2.5) * 2 - ) - @test ( - 2 * Tangent{Tuple{Float64,Float64}}(2.0, 4.0) == - Tangent{Tuple{Float64,Float64}}(4.0, 8.0) == - Tangent{Tuple{Float64,Float64}}(2.0, 4.0) * 2 - ) - d = Tangent{Dict}(Dict(4 => 3.0)) - two_d = Tangent{Dict}(Dict(4 => 2 * 3.0)) - @test 2 * d == two_d == d * 2 - - @test_throws MethodError [1, 2] * Tangent{Foo}(; y=1.5, x=2.5) - @test_throws MethodError [1, 2] * d - @test_throws MethodError Tangent{Foo}(; y=1.5, x=2.5) * @thunk [1 2; 3 4] - end - - @testset "iszero" begin - @test iszero(Tangent{Foo}()) - @test iszero(Tangent{Tuple{}}()) - @test iszero(Tangent{Foo}(; x=ZeroTangent())) - @test iszero(Tangent{Foo}(; y=0.0)) - @test iszero(Tangent{Foo}(; x=Tangent{Tuple{}}(), y=0.0)) - - @test !iszero(Tangent{Foo}(; y=3.0)) - end - - @testset "show" begin - @test repr(Tangent{Foo}(; x=1)) == "Tangent{Foo}(x = 1,)" - # check for exact regex match not occurence( `^...$`) - # and allowing optional whitespace (`\s?`) - @test occursin( - r"^Tangent{Tuple{Int64,\s?Int64}}\(1,\s?2\)$", - repr(Tangent{Tuple{Int64,Int64}}(1, 2)), - ) - - @test 
repr(Tangent{Foo}()) == "Tangent{Foo}()" - end - - @testset "internals" begin - @testset "Can't do backing on primative type" begin - @test_throws Exception ChainRulesCore.backing(1.4) - end - - @testset "Internals don't allocate a ton" begin - bk = (; x=1.0, y=2.0) - VERSION >= v"1.5" && - @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 - - # weaker version of the above (which should pass on all versions) - @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 48 - @test (@ballocated ChainRulesCore.elementwise_add($bk, $bk)) <= 48 - end - end - - @testset "non-same-typed differential arithmetic" begin - nt = (; a=1, b=2.0) - c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1) - @test nt + c == (; a=1, b=2.1) - end - - @testset "printing" begin - t5 = Tuple(rand(3)) - nt3 = (x=t5, y=t5, z=nothing) - tang = ProjectTo(nt3)(nt3) # moderately complicated Tangent - @test contains(sprint(show, tang), "...}(x = Tangent") # gets shortened - @test contains(sprint(show, tang), sprint(show, tang.x)) # inner piece appears whole - end -end