From 7f15d119ba68bcae3f106a3e1e90f696bd374623 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 4 Aug 2023 04:11:29 -0400 Subject: [PATCH 01/34] rename files --- src/ChainRulesCore.jl | 2 +- src/tangent_types/{tangent.jl => structural_tangent.jl} | 0 test/runtests.jl | 2 +- test/tangent_types/{tangent.jl => structural_tangent.jl} | 0 4 files changed, 2 insertions(+), 2 deletions(-) rename src/tangent_types/{tangent.jl => structural_tangent.jl} (100%) rename test/tangent_types/{tangent.jl => structural_tangent.jl} (100%) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 94e8242b1..f943c50fa 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -21,7 +21,7 @@ include("debug_mode.jl") include("tangent_types/abstract_tangent.jl") include("tangent_types/abstract_zero.jl") include("tangent_types/thunks.jl") -include("tangent_types/tangent.jl") +include("tangent_types/structural_tangent.jl") include("tangent_types/notimplemented.jl") include("tangent_arithmetic.jl") diff --git a/src/tangent_types/tangent.jl b/src/tangent_types/structural_tangent.jl similarity index 100% rename from src/tangent_types/tangent.jl rename to src/tangent_types/structural_tangent.jl diff --git a/test/runtests.jl b/test/runtests.jl index 6a4684d03..a3b0971a5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,7 +11,7 @@ using Test @testset "differentials" begin include("tangent_types/abstract_zero.jl") include("tangent_types/thunks.jl") - include("tangent_types/tangent.jl") + include("tangent_types/structural_tangent.jl") include("tangent_types/notimplemented.jl") end diff --git a/test/tangent_types/tangent.jl b/test/tangent_types/structural_tangent.jl similarity index 100% rename from test/tangent_types/tangent.jl rename to test/tangent_types/structural_tangent.jl From c9b49386e9d0562d5301d33372499fa6e2def93d Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 4 Aug 2023 06:35:40 -0400 Subject: [PATCH 02/34] move functionality up to StructuralTangent --- src/tangent_arithmetic.jl | 22 +- src/tangent_types/structural_tangent.jl | 369 +++++++++++++----------- 2 files changed, 211 insertions(+), 180 deletions(-) diff --git a/src/tangent_arithmetic.jl b/src/tangent_arithmetic.jl index 439f0ac8f..18ae7b3ad 100644 --- a/src/tangent_arithmetic.jl +++ b/src/tangent_arithmetic.jl @@ -20,7 +20,7 @@ Base.:+(x::NotImplemented, ::NotImplemented) = x Base.:*(x::NotImplemented, ::NotImplemented) = x LinearAlgebra.dot(x::NotImplemented, ::NotImplemented) = x # `NotImplemented` always "wins" + -for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :Tangent, :Any) +for T in (:ZeroTangent, :NoTangent, :AbstractThunk, :StructuralTangent, :Any) @eval Base.:+(x::NotImplemented, ::$T) = x @eval Base.:+(::$T, x::NotImplemented) = x end @@ -33,7 +33,7 @@ for T in (:ZeroTangent, :NoTangent) @eval LinearAlgebra.dot(::$T, ::NotImplemented) = $T() end # `NotImplemented` "wins" * and dot for other types -for T in (:AbstractThunk, :Tangent, :Any) +for T in (:AbstractThunk, :StructuralTangent, :Any) @eval Base.:*(x::NotImplemented, ::$T) = x @eval Base.:*(::$T, x::NotImplemented) = x @eval LinearAlgebra.dot(x::NotImplemented, ::$T) = x @@ -55,7 +55,7 @@ Base.:-(::NoTangent, ::NoTangent) = NoTangent() Base.:-(::NoTangent) = NoTangent() Base.:*(::NoTangent, ::NoTangent) = NoTangent() LinearAlgebra.dot(::NoTangent, ::NoTangent) = NoTangent() -for T in (:AbstractThunk, :Tangent, :Any) +for T in (:AbstractThunk, :StructuralTangent, :Any) @eval Base.:+(::NoTangent, b::$T) = b @eval Base.:+(a::$T, ::NoTangent) = a @eval Base.:-(::NoTangent, b::$T) = -b @@ -95,7 +95,7 @@ Base.:-(::ZeroTangent, ::ZeroTangent) = ZeroTangent() Base.:-(::ZeroTangent) = ZeroTangent() Base.:*(::ZeroTangent, ::ZeroTangent) = ZeroTangent() LinearAlgebra.dot(::ZeroTangent, ::ZeroTangent) = ZeroTangent() -for T in (:AbstractThunk, :Tangent, :Any) +for T in (:AbstractThunk, :StructuralTangent, :Any) @eval Base.:+(::ZeroTangent, b::$T) = b @eval Base.:+(a::$T, ::ZeroTangent) = a @eval Base.:-(::ZeroTangent, b::$T) = -b @@ -126,11 +126,11 @@ for T in (:Tangent, :Any) @eval Base.:*(a::$T, b::AbstractThunk) = a * unthunk(b) end -function Base.:+(a::Tangent{P}, b::Tangent{P}) where {P} +function Base.:+(a::StructuralTangent{P}, b::StructuralTangent{P}) where {P} data = elementwise_add(backing(a), backing(b)) - return Tangent{P,typeof(data)}(data) + return StructuralTangent{P}(data) end -function Base.:+(a::P, d::Tangent{P}) where {P} +function Base.:+(a::P, d::StructuralTangent{P}) where {P} net_backing = elementwise_add(backing(a), backing(d)) if debug_mode() try @@ -143,14 +143,14 @@ function Base.:+(a::P, d::Tangent{P}) where {P} end end Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d)) -Base.:+(a::Tangent{P}, b::P) where {P} = b + a +Base.:+(a::StructuralTangent{P}, b::P) where {P} = b + a -Base.:-(tangent::Tangent{P}) where {P} = map(-, tangent) +Base.:-(tangent::StructuralTangent{P}) where {P} = map(-, tangent) # We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful # In general one doesn't have to represent multiplications of 2 tangents # Only of a tangent and a scaling factor (generally `Real`) for T in (:Number,) - @eval Base.:*(s::$T, tangent::Tangent) = map(x -> s * x, tangent) - @eval Base.:*(tangent::Tangent, s::$T) = map(x -> x * s, tangent) + @eval Base.:*(s::$T, tangent::StructuralTangent) = map(x -> s * x, tangent) + @eval Base.:*(tangent::StructuralTangent, s::$T) = map(x -> x * s, tangent) end diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 6af968c53..a8fd4ac1b 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -1,5 +1,201 @@ """ - Tangent{P, T} <: AbstractTangent + StructuralTangent{P} <: AbstractTangent + +Representing the type of the tangent of a `struct` `P` (or a `Tuple`/`NamedTuple`). +as an object with mirroring fields. +""" +abstract type StructuralTangent{P} <: AbstractTangent end + +function StructuralTangent{P}(nt::NamedTuple) where P + return Tangent{P, typeof(nt)}(nt) +end + +StructuralTangent{P}(tup::Tuple) where P = Tangent{P, typeof(tup)}(tup) +StructuralTangent{P}(dict::Dict) where P = Tangent{P}(dict) + + +Base.keys(tangent::StructuralTangent) = keys(backing(tangent)) +Base.propertynames(tangent::StructuralTangent) = propertynames(backing(tangent)) + +Base.haskey(tangent::StructuralTangent, key) = haskey(backing(tangent), key) +if isdefined(Base, :hasproperty) + Base.hasproperty(tangent::StructuralTangent, key::Symbol) = hasproperty(backing(tangent), key) +end + +Base.iszero(t::StructuralTangent) = all(iszero, backing(t)) + +function Base.map(f, tangent::StructuralTangent{P}) where {P} + L = propertynames(backing(tangent)) + vals = map(f, Tuple(backing(tangent))) + named_vals = NamedTuple{L,typeof(vals)}(vals) + return if tangent isa Tangent + Tangent{P, typeof(named_vals)}(named_vals) + else + # Handle MutableTangent + end +end + + +""" + backing(x) + +Accesses the backing field of a `Tangent`, +or destructures any other struct type into a `NamedTuple`. +Identity function on `Tuple`s and `NamedTuple`s. + +This is an internal function used to simplify operations between `Tangent`s and the +primal types. +""" +backing(x::Tuple) = x +backing(x::NamedTuple) = x +backing(x::Dict) = x +backing(x::StructuralTangent) = getfield(x, :backing) + +# For generic structs +function backing(x::T)::NamedTuple where {T} + # note: all computation outside the if @generated happens at runtime. + # so the first 4 lines of the branchs look the same, but can not be moved out. + # see https://github.com/JuliaLang/julia/issues/34283 + if @generated + !isstructtype(T) && + throw(DomainError(T, "backing can only be used on struct types")) + nfields = fieldcount(T) + names = fieldnames(T) + types = fieldtypes(T) + + vals = Expr(:tuple, ntuple(ii -> :(getfield(x, $ii)), nfields)...) + return :(NamedTuple{$names,Tuple{$(types...)}}($vals)) + else + !isstructtype(T) && + throw(DomainError(T, "backing can only be used on struct types")) + nfields = fieldcount(T) + names = fieldnames(T) + types = fieldtypes(T) + + vals = ntuple(ii -> getfield(x, ii), nfields) + return NamedTuple{names,Tuple{types...}}(vals) + end +end + + +""" + _zeroed_backing(P) + +Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`. +""" +@generated function _zeroed_backing(::Type{P}) where {P} + nil_base = ntuple(fieldcount(P)) do i + (fieldname(P, i), ZeroTangent()) + end + return (; nil_base...) +end + +""" + construct(::Type{T}, fields::[NamedTuple|Tuple]) + +Constructs an object of type `T`, with the given fields. +Fields must be correct in name and type, and `T` must have a default constructor. + +This internally is called to construct structs of the primal type `T`, +after an operation such as the addition of a primal to a tangent + +It should be overloaded, if `T` does not have a default constructor, +or if `T` needs to maintain some invarients between its fields. +""" +function construct(::Type{T}, fields::NamedTuple{L}) where {T,L} + # Tested and verified that that this avoids a ton of allocations + if length(L) !== fieldcount(T) + # if length is equal but names differ then we will catch that below anyway. + throw(ArgumentError("Unmatched fields. Type: $(fieldnames(T)), NamedTuple: $L")) + end + + if @generated + vals = (:(getproperty(fields, $(QuoteNode(fname)))) for fname in fieldnames(T)) + return :(T($(vals...))) + else + return T((getproperty(fields, fname) for fname in fieldnames(T))...) + end +end + +construct(::Type{T}, fields::T) where {T<:NamedTuple} = fields +construct(::Type{T}, fields::T) where {T<:Tuple} = fields + +elementwise_add(a::Tuple, b::Tuple) = map(+, a, b) + +function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn} + # Rule of Tangent addition: any fields not present are implict hard Zeros + + # Base on the `merge(:;NamedTuple, ::NamedTuple)` code from Base. + # https://github.com/JuliaLang/julia/blob/592748adb25301a45bd6edef3ac0a93eed069852/base/namedtuple.jl#L220-L231 + if @generated + names = Base.merge_names(an, bn) + + vals = map(names) do field + a_field = :(getproperty(a, $(QuoteNode(field)))) + b_field = :(getproperty(b, $(QuoteNode(field)))) + value_expr = if Base.sym_in(field, an) + if Base.sym_in(field, bn) + # in both + :($a_field + $b_field) + else + # only in `an` + a_field + end + else # must be in `b` only + b_field + end + Expr(:kw, field, value_expr) + end + return Expr(:tuple, Expr(:parameters, vals...)) + else + names = Base.merge_names(an, bn) + vals = map(names) do field + value = if Base.sym_in(field, an) + a_field = getproperty(a, field) + if Base.sym_in(field, bn) + # in both + b_field = getproperty(b, field) + a_field + b_field + else + # only in `an` + a_field + end + else # must be in `b` only + getproperty(b, field) + end + field => value + end + return (; vals...) + end +end + +elementwise_add(a::Dict, b::Dict) = merge(+, a, b) + +struct PrimalAdditionFailedException{P} <: Exception + primal::P + tangent + original::Exception +end + +function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} + println(io, "Could not construct $P after addition.") + println(io, "This probably means no default constructor is defined.") + println(io, "Either define a default constructor") + printstyled(io, "$P(", join(propertynames(err.tangent), ", "), ")"; color=:blue) + println(io, "\nor overload") + printstyled( + io, "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.tangent)))"; color=:blue + ) + println(io, "\nor overload") + printstyled(io, "Base.:+(::$P, ::$(typeof(err.tangent)))"; color=:blue) + println(io, "\nOriginal Exception:") + printstyled(io, err.original; color=:yellow) + return println(io) +end + + +""" + Tangent{P, T} <: StructuralTangent{P} <: AbstractTangent This type represents the tangent for a `struct`/`NamedTuple`, or `Tuple`. `P` is the the corresponding primal type that this is a tangent for. @@ -21,7 +217,7 @@ Any fields not explictly present in the `Tangent` are treated as being set to `Z To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref) function is provided. """ -struct Tangent{P,T} <: AbstractTangent +struct Tangent{P,T} <: StructuralTangent{P} # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict # (but potentially a different one, as it doesn't contain tangents) backing::T @@ -62,6 +258,7 @@ function _backing_error(P, G, E) return throw(ArgumentError(msg)) end + function Base.:(==)(a::Tangent{P,T}, b::Tangent{P,T}) where {P,T} return backing(a) == backing(b) end @@ -98,7 +295,7 @@ end Base.iszero(::Tangent{<:,NamedTuple{}}) = true Base.iszero(::Tangent{<:,Tuple{}}) = true -Base.iszero(t::Tangent) = all(iszero, backing(t)) + Base.first(tangent::Tangent{P,T}) where {P,T<:Union{Tuple,NamedTuple}} = first(backing(canonicalize(tangent))) Base.last(tangent::Tangent{P,T}) where {P,T<:Union{Tuple,NamedTuple}} = last(backing(canonicalize(tangent))) @@ -134,13 +331,6 @@ function Base.getproperty(tangent::Tangent{P,T}, idx::Symbol) where {P,T<:NamedT return unthunk(getfield(backing(tangent), idx)) end -Base.keys(tangent::Tangent) = keys(backing(tangent)) -Base.propertynames(tangent::Tangent) = propertynames(backing(tangent)) - -Base.haskey(tangent::Tangent, key) = haskey(backing(tangent), key) -if isdefined(Base, :hasproperty) - Base.hasproperty(tangent::Tangent, key::Symbol) = hasproperty(backing(tangent), key) -end Base.iterate(tangent::Tangent, args...) = iterate(backing(tangent), args...) Base.length(tangent::Tangent) = length(backing(tangent)) @@ -159,57 +349,13 @@ function Base.map(f, tangent::Tangent{P,<:Tuple}) where {P} vals::Tuple = map(f, backing(tangent)) return Tangent{P,typeof(vals)}(vals) end -function Base.map(f, tangent::Tangent{P,<:NamedTuple{L}}) where {P,L} - vals = map(f, Tuple(backing(tangent))) - named_vals = NamedTuple{L,typeof(vals)}(vals) - return Tangent{P,typeof(named_vals)}(named_vals) -end function Base.map(f, tangent::Tangent{P,<:Dict}) where {P<:Dict} return Tangent{P}(Dict(k => f(v) for (k, v) in backing(tangent))) end Base.conj(tangent::Tangent) = map(conj, tangent) -""" - backing(x) - -Accesses the backing field of a `Tangent`, -or destructures any other struct type into a `NamedTuple`. -Identity function on `Tuple`s and `NamedTuple`s. - -This is an internal function used to simplify operations between `Tangent`s and the -primal types. -""" -backing(x::Tuple) = x -backing(x::NamedTuple) = x -backing(x::Dict) = x -backing(x::Tangent) = getfield(x, :backing) - -# For generic structs -function backing(x::T)::NamedTuple where {T} - # note: all computation outside the if @generated happens at runtime. - # so the first 4 lines of the branchs look the same, but can not be moved out. - # see https://github.com/JuliaLang/julia/issues/34283 - if @generated - !isstructtype(T) && - throw(DomainError(T, "backing can only be used on struct types")) - nfields = fieldcount(T) - names = fieldnames(T) - types = fieldtypes(T) - - vals = Expr(:tuple, ntuple(ii -> :(getfield(x, $ii)), nfields)...) - return :(NamedTuple{$names,Tuple{$(types...)}}($vals)) - else - !isstructtype(T) && - throw(DomainError(T, "backing can only be used on struct types")) - nfields = fieldcount(T) - names = fieldnames(T) - types = fieldtypes(T) - vals = ntuple(ii -> getfield(x, ii), nfields) - return NamedTuple{names,Tuple{types...}}(vals) - end -end """ canonicalize(tangent::Tangent{P}) -> Tangent{P} @@ -243,118 +389,3 @@ canonicalize(tangent::Tangent{<:Any,<:AbstractDict}) = tangent canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent canonicalize(tangent::Tangent{Any,<:Tuple}) = tangent canonicalize(tangent::Tangent{Any,<:AbstractDict}) = tangent - -""" - _zeroed_backing(P) - -Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`. -""" -@generated function _zeroed_backing(::Type{P}) where {P} - nil_base = ntuple(fieldcount(P)) do i - (fieldname(P, i), ZeroTangent()) - end - return (; nil_base...) -end - -""" - construct(::Type{T}, fields::[NamedTuple|Tuple]) - -Constructs an object of type `T`, with the given fields. -Fields must be correct in name and type, and `T` must have a default constructor. - -This internally is called to construct structs of the primal type `T`, -after an operation such as the addition of a primal to a tangent - -It should be overloaded, if `T` does not have a default constructor, -or if `T` needs to maintain some invarients between its fields. -""" -function construct(::Type{T}, fields::NamedTuple{L}) where {T,L} - # Tested and verified that that this avoids a ton of allocations - if length(L) !== fieldcount(T) - # if length is equal but names differ then we will catch that below anyway. - throw(ArgumentError("Unmatched fields. Type: $(fieldnames(T)), NamedTuple: $L")) - end - - if @generated - vals = (:(getproperty(fields, $(QuoteNode(fname)))) for fname in fieldnames(T)) - return :(T($(vals...))) - else - return T((getproperty(fields, fname) for fname in fieldnames(T))...) - end -end - -construct(::Type{T}, fields::T) where {T<:NamedTuple} = fields -construct(::Type{T}, fields::T) where {T<:Tuple} = fields - -elementwise_add(a::Tuple, b::Tuple) = map(+, a, b) - -function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn} - # Rule of Tangent addition: any fields not present are implict hard Zeros - - # Base on the `merge(:;NamedTuple, ::NamedTuple)` code from Base. - # https://github.com/JuliaLang/julia/blob/592748adb25301a45bd6edef3ac0a93eed069852/base/namedtuple.jl#L220-L231 - if @generated - names = Base.merge_names(an, bn) - - vals = map(names) do field - a_field = :(getproperty(a, $(QuoteNode(field)))) - b_field = :(getproperty(b, $(QuoteNode(field)))) - value_expr = if Base.sym_in(field, an) - if Base.sym_in(field, bn) - # in both - :($a_field + $b_field) - else - # only in `an` - a_field - end - else # must be in `b` only - b_field - end - Expr(:kw, field, value_expr) - end - return Expr(:tuple, Expr(:parameters, vals...)) - else - names = Base.merge_names(an, bn) - vals = map(names) do field - value = if Base.sym_in(field, an) - a_field = getproperty(a, field) - if Base.sym_in(field, bn) - # in both - b_field = getproperty(b, field) - a_field + b_field - else - # only in `an` - a_field - end - else # must be in `b` only - getproperty(b, field) - end - field => value - end - return (; vals...) - end -end - -elementwise_add(a::Dict, b::Dict) = merge(+, a, b) - -struct PrimalAdditionFailedException{P} <: Exception - primal::P - tangent::Tangent{P} - original::Exception -end - -function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} - println(io, "Could not construct $P after addition.") - println(io, "This probably means no default constructor is defined.") - println(io, "Either define a default constructor") - printstyled(io, "$P(", join(propertynames(err.tangent), ", "), ")"; color=:blue) - println(io, "\nor overload") - printstyled( - io, "ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.tangent)))"; color=:blue - ) - println(io, "\nor overload") - printstyled(io, "Base.:+(::$P, ::$(typeof(err.tangent)))"; color=:blue) - println(io, "\nOriginal Exception:") - printstyled(io, err.original; color=:yellow) - return println(io) -end From a51e51e8a7be5a5996105bb313f44d342601a42d Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 4 Aug 2023 07:00:51 -0400 Subject: [PATCH 03/34] Formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/structural_tangent.jl | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index a8fd4ac1b..9ea665735 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -6,20 +6,21 @@ as an object with mirroring fields. """ abstract type StructuralTangent{P} <: AbstractTangent end -function StructuralTangent{P}(nt::NamedTuple) where P - return Tangent{P, typeof(nt)}(nt) +function StructuralTangent{P}(nt::NamedTuple) where {P} + return Tangent{P,typeof(nt)}(nt) end -StructuralTangent{P}(tup::Tuple) where P = Tangent{P, typeof(tup)}(tup) -StructuralTangent{P}(dict::Dict) where P = Tangent{P}(dict) - +StructuralTangent{P}(tup::Tuple) where {P} = Tangent{P,typeof(tup)}(tup) +StructuralTangent{P}(dict::Dict) where {P} = Tangent{P}(dict) Base.keys(tangent::StructuralTangent) = keys(backing(tangent)) Base.propertynames(tangent::StructuralTangent) = propertynames(backing(tangent)) Base.haskey(tangent::StructuralTangent, key) = haskey(backing(tangent), key) if isdefined(Base, :hasproperty) - Base.hasproperty(tangent::StructuralTangent, key::Symbol) = hasproperty(backing(tangent), key) + function Base.hasproperty(tangent::StructuralTangent, key::Symbol) + return hasproperty(backing(tangent), key) + end end Base.iszero(t::StructuralTangent) = all(iszero, backing(t)) @@ -29,13 +30,12 @@ function Base.map(f, tangent::StructuralTangent{P}) where {P} vals = map(f, Tuple(backing(tangent))) named_vals = NamedTuple{L,typeof(vals)}(vals) return if tangent isa Tangent - Tangent{P, typeof(named_vals)}(named_vals) + Tangent{P,typeof(named_vals)}(named_vals) else # Handle MutableTangent end end - """ backing(x) @@ -77,7 +77,6 @@ function backing(x::T)::NamedTuple where {T} end end - """ _zeroed_backing(P) @@ -193,7 +192,6 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} return println(io) end - """ Tangent{P, T} <: StructuralTangent{P} <: AbstractTangent @@ -258,7 +256,6 @@ function _backing_error(P, G, E) return throw(ArgumentError(msg)) end - function Base.:(==)(a::Tangent{P,T}, b::Tangent{P,T}) where {P,T} return backing(a) == backing(b) end From 93af90b2e5530598974c2b4d127f25933e286230 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 21 Aug 2023 17:46:22 +0800 Subject: [PATCH 04/34] WIP mutable Tangent (squash me) --- src/tangent_types/structural_tangent.jl | 49 ++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 9ea665735..9410e50af 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -3,15 +3,31 @@ Representing the type of the tangent of a `struct` `P` (or a `Tuple`/`NamedTuple`). as an object with mirroring fields. + +!!!!!! warning Exprimental + The `StructuralTangent` constructor returns a `MutableTangent` for mutable structs. + `MutableTangent` is an experimental feature. + Thus use of `StructuralTangent` (rather than `Tangent` directly) is also experimental. + While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore. + """ abstract type StructuralTangent{P} <: AbstractTangent end function StructuralTangent{P}(nt::NamedTuple) where {P} - return Tangent{P,typeof(nt)}(nt) + if ismutabletype(P) + return MutableTangent{P}(nt) + else + return Tangent{P,typeof(nt)}(nt) + end end -StructuralTangent{P}(tup::Tuple) where {P} = Tangent{P,typeof(tup)}(tup) -StructuralTangent{P}(dict::Dict) where {P} = Tangent{P}(dict) +ismutabletype(::Type{P}) where P = ismutable(P) +ismutabletype(::Type{String}) = false +ismutabletype(::Type{Symbol}) = false + + +StructuralTangent{P}(tup::Tuple) where P = Tangent{P,typeof(tup)}(tup) +StructuralTangent{P}(dict::Dict) where P = Tangent{P}(dict) Base.keys(tangent::StructuralTangent) = keys(backing(tangent)) Base.propertynames(tangent::StructuralTangent) = propertynames(backing(tangent)) @@ -29,10 +45,10 @@ function Base.map(f, tangent::StructuralTangent{P}) where {P} L = propertynames(backing(tangent)) vals = map(f, Tuple(backing(tangent))) named_vals = NamedTuple{L,typeof(vals)}(vals) - return if tangent isa Tangent - Tangent{P,typeof(named_vals)}(named_vals) + return if tangent isa MutableTangent + MutableTangent{P}(named_vals) else - # Handle MutableTangent + Tangent{P,typeof(named_vals)}(named_vals) end end @@ -386,3 +402,24 @@ canonicalize(tangent::Tangent{<:Any,<:AbstractDict}) = tangent canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent canonicalize(tangent::Tangent{Any,<:Tuple}) = tangent canonicalize(tangent::Tangent{Any,<:AbstractDict}) = tangent + + +""" + MutableTangent{P}(fields) <: StructuralTangent{P} <: AbstractTangent + +This type represents the tangent to a mutable struct. +It itself is also mutable. + +!!!!!! warning Exprimental + MutableTangent is an experimental feature. + While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore. + Exactly how it should be used (e.g. is it forward-mode only?) +""" +mutable struct MutableTangent{P} + # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict + # (but potentially a different one, as it doesn't contain tangents) + backing::NamedTuple +end + +Base.getproperty(tangent::MutableTangent, idx::Symbol) = unthunk(getfield(backing(tangent), idx)) +Base.setproperty! \ No newline at end of file From 418b5cece96e95e0393f86dd45b6534a321bd1b6 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 15 Sep 2023 17:18:21 +0800 Subject: [PATCH 05/34] wip --- src/tangent_types/structural_tangent.jl | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 9410e50af..7df7008db 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -14,16 +14,14 @@ as an object with mirroring fields. abstract type StructuralTangent{P} <: AbstractTangent end function StructuralTangent{P}(nt::NamedTuple) where {P} - if ismutabletype(P) + if has_mutable_tangent(P) return MutableTangent{P}(nt) else return Tangent{P,typeof(nt)}(nt) end end -ismutabletype(::Type{P}) where P = ismutable(P) -ismutabletype(::Type{String}) = false -ismutabletype(::Type{Symbol}) = false +has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(T) > 0) StructuralTangent{P}(tup::Tuple) where P = Tangent{P,typeof(tup)}(tup) @@ -410,16 +408,22 @@ canonicalize(tangent::Tangent{Any,<:AbstractDict}) = tangent This type represents the tangent to a mutable struct. It itself is also mutable. -!!!!!! warning Exprimental +!!! warning Exprimental MutableTangent is an experimental feature. While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore. Exactly how it should be used (e.g. is it forward-mode only?) + +!!! warning Do not directly mess with the tangent backing data + It is relatively straight forward for a forwards-mode AD to work correctly in the presence of mutation and aliasing of primal values. + However, this requires that the tangent is aliased in turn and conversely that it is copied when the primal is). + If you seperately alias the backing data, etc by using the internal `ChainRulesCore.backing` function you can break this. """ mutable struct MutableTangent{P} - # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict - # (but potentially a different one, as it doesn't contain tangents) - backing::NamedTuple + #TODO: we may want to absolutely lock the type of this down + backing::NamedTuple end Base.getproperty(tangent::MutableTangent, idx::Symbol) = unthunk(getfield(backing(tangent), idx)) -Base.setproperty! \ No newline at end of file +function Base.setproperty!(tangent::MutableTangent, name::Symbol, x) + new_backing = Base.setindex(backing(tangent), x, name) +end \ No newline at end of file From 724ba1b812f8641361285ca27eefcc9b2ecaf642 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 15 Sep 2023 20:23:27 +0800 Subject: [PATCH 06/34] First pass at something that maybe works --- src/ChainRulesCore.jl | 2 +- src/tangent_types/structural_tangent.jl | 7 +++++-- test/tangent_types/structural_tangent.jl | 18 ++++++++++++++++++ 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index f943c50fa..4b86570cd 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -14,7 +14,7 @@ export ProjectTo, canonicalize, unthunk # tangent operations export add!!, is_inplaceable_destination # gradient accumulation operations export ignore_derivatives, @ignore_derivatives # tangents -export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk +export Tangent, MutableTangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk include("debug_mode.jl") diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 7df7008db..18e40fea9 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -418,12 +418,15 @@ It itself is also mutable. However, this requires that the tangent is aliased in turn and conversely that it is copied when the primal is). If you seperately alias the backing data, etc by using the internal `ChainRulesCore.backing` function you can break this. """ -mutable struct MutableTangent{P} +mutable struct MutableTangent{P} <: StructuralTangent{P} #TODO: we may want to absolutely lock the type of this down backing::NamedTuple end -Base.getproperty(tangent::MutableTangent, idx::Symbol) = unthunk(getfield(backing(tangent), idx)) +MutableTangent{P}(;kwargs...) where P = MutableTangent{P}(NamedTuple(kwargs)) +Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(backing(tangent), idx) function Base.setproperty!(tangent::MutableTangent, name::Symbol, x) new_backing = Base.setindex(backing(tangent), x, name) + setfield!(tangent, :backing, new_backing) + return x end \ No newline at end of file diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index b0cb5577e..671d2fad0 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -425,3 +425,21 @@ end @test contains(sprint(show, tang), sprint(show, tang.x)) # inner piece appears whole end end + +@testset "MutableTangent" begin + mutable struct MDemo + x::Float64 + end + function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setfield!), obj::MDemo, field, x) + y = setfield!(obj, field, x) + ẏ = setproperty!(ȯbj, field, ẋ) + return y, ẏ + end + + obj = MDemo(99.0) + ∂obj = MutableTangent{MDemo}(;x=1.5) + frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0) + + @test ∂obj.x == 10.0 + @test obj.x == 95.0 +end \ No newline at end of file From c9a65df124e3e6cccf825e1d40b81ae44dffa5f7 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 18 Sep 2023 13:38:22 +0800 Subject: [PATCH 07/34] accept int index --- src/tangent_types/structural_tangent.jl | 13 ++++++++++++- test/tangent_types/structural_tangent.jl | 6 +++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 18e40fea9..f26c39874 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -424,9 +424,20 @@ mutable struct MutableTangent{P} <: StructuralTangent{P} end MutableTangent{P}(;kwargs...) where P = MutableTangent{P}(NamedTuple(kwargs)) + Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(backing(tangent), idx) +Base.getproperty(tangent::MutableTangent, idx::Int) = getfield(backing(tangent), idx) # break ambig + function Base.setproperty!(tangent::MutableTangent, name::Symbol, x) new_backing = Base.setindex(backing(tangent), x, name) setfield!(tangent, :backing, new_backing) return x -end \ No newline at end of file +end + +function Base.setproperty!(tangent::MutableTangent, idx::Int, x) + # needed due to https://github.com/JuliaLang/julia/issues/43155 + name = idx2sym(backing(tangent), idx) + return setproperty!(tangent, name, x) +end + +idx2sym(::NamedTuple{names}, idx) where names = names[idx] \ No newline at end of file diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index 671d2fad0..03e4db45e 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -439,7 +439,11 @@ end obj = MDemo(99.0) ∂obj = MutableTangent{MDemo}(;x=1.5) frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0) - @test ∂obj.x == 10.0 @test obj.x == 95.0 + + frule((NoTangent(), ∂obj, NoTangent(), 20.0), setfield!, obj, 1, 96.0) + @test ∂obj.x == 20.0 + @test getproperty(∂obj, 1) == 20.0 + @test obj.x == 96.0 end \ No newline at end of file From 06b51f5d38289cf5f330a068eb423e549b5044cc Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 26 Sep 2023 20:20:56 +0800 Subject: [PATCH 08/34] add == and hash for MutableTangent --- src/tangent_types/structural_tangent.jl | 11 ++++++-- test/tangent_types/structural_tangent.jl | 34 +++++++++++++++++------- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index f26c39874..dd7a53ec6 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -5,6 +5,7 @@ Representing the type of the tangent of a `struct` `P` (or a `Tuple`/`NamedTuple as an object with mirroring fields. !!!!!! warning Exprimental + `StructuralTangent` is an experimental feature, and is part of the mutation support featureset. The `StructuralTangent` constructor returns a `MutableTangent` for mutable structs. `MutableTangent` is an experimental feature. Thus use of `StructuralTangent` (rather than `Tangent` directly) is also experimental. @@ -409,7 +410,7 @@ This type represents the tangent to a mutable struct. It itself is also mutable. !!! warning Exprimental - MutableTangent is an experimental feature. + MutableTangent is an experimental feature, and is part of the mutation support featureset. While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore. Exactly how it should be used (e.g. is it forward-mode only?) @@ -440,4 +441,10 @@ function Base.setproperty!(tangent::MutableTangent, idx::Int, x) return setproperty!(tangent, name, x) end -idx2sym(::NamedTuple{names}, idx) where names = names[idx] \ No newline at end of file +idx2sym(::NamedTuple{names}, idx) where names = names[idx] + +Base.hash(tangent::MutableTangent, h::UInt64) = hash(backing(tangent), h) +function Base.:(==)(t1::MutableTangent{T1}, t2::MutableTangent{T2}) where {T1, T2} + typeintersect(T1, T2) == Union{} && return false + backing(t1)==backing(t2) +end \ No newline at end of file diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index 03e4db45e..5736b2f1d 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -436,14 +436,28 @@ end return y, ẏ end - obj = MDemo(99.0) - ∂obj = MutableTangent{MDemo}(;x=1.5) - frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0) - @test ∂obj.x == 10.0 - @test obj.x == 95.0 - - frule((NoTangent(), ∂obj, NoTangent(), 20.0), setfield!, obj, 1, 96.0) - @test ∂obj.x == 20.0 - @test getproperty(∂obj, 1) == 20.0 - @test obj.x == 96.0 + @testset "usecase" begin + obj = MDemo(99.0) + ∂obj = MutableTangent{MDemo}(;x=1.5) + frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0) + @test ∂obj.x == 10.0 + @test obj.x == 95.0 + + frule((NoTangent(), ∂obj, NoTangent(), 20.0), setfield!, obj, 1, 96.0) + @test ∂obj.x == 20.0 + @test getproperty(∂obj, 1) == 20.0 + @test obj.x == 96.0 + end + + @testset "== and hash" begin + @test MutableTangent{Any}(x=1.0) == MutableTangent{MDemo}(x=1.0) + @test MutableTangent{MDemo}(x=1.0) == MutableTangent{Any}(x=1.0) + @test MutableTangent{Any}(x=2.0) != MutableTangent{MDemo}(x=1.0) + @test MutableTangent{MDemo}(x=1.0) != MutableTangent{Any}(x=2.0) + + nt = (;x=1.0) + @test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(x=1.0) + + @test hash(MutableTangent{Any}(x=1.0)) == hash(MutableTangent{MDemo}(x=1.0)) + end end \ No newline at end of file From ed3aa1da4ed681332aeeee8c2cbbce86be6ddaba Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 26 Sep 2023 21:44:01 +0800 Subject: [PATCH 09/34] add and test zero_tangent --- src/ChainRulesCore.jl | 2 +- src/tangent_types/abstract_zero.jl | 29 +++++++++++++++++++++++++ src/tangent_types/structural_tangent.jl | 2 +- test/tangent_types/abstract_zero.jl | 14 ++++++++++++ 4 files changed, 45 insertions(+), 2 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 4b86570cd..51db59b64 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -10,7 +10,7 @@ export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMod export frule_via_ad, rrule_via_ad # definition helper macros export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented -export ProjectTo, canonicalize, unthunk # tangent operations +export ProjectTo, canonicalize, unthunk, zero_tangent # tangent operations export add!!, is_inplaceable_destination # gradient accumulation operations export ignore_derivatives, @ignore_derivatives # tangents diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 77c455c04..dd9068b74 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -91,3 +91,32 @@ arguments. ``` """ struct NoTangent <: AbstractZero end + +""" + zero_tangent(primal) + +This returns an appropriate zero tangent suitable for accumulating tangents of the primal. +For mutable composites types this is a structural []`MutableTangent`](@ref) +For `Array`s, it is applied recursively for each element. +For immutable types, this is simply [`ZeroTangent()`](@ref) as accumulation is default out-of-place for contexts where mutation does not apply. +(Where mutation is not to be supported even for mutable types, then [`ZeroTangent()`](@ref) should be used for everything) + +!!! warning Exprimental + `zero_tangent`is an experimental feature, and is part of the mutation support featureset. + While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore. + Exactly how it should be used (e.g. is it forward-mode only?) +""" +function zero_tangent end +zero_tangent(::AbstractString) = ZeroTangent() +# zero_tangent(::Number) = zero(x) # TODO: do we want this? +zero_tangent(primal::Array{<:Number}) = zero(primal) # TODO: do we want this? +zero_tangent(primal::Array) = map(zero_tangent, primal) +@generated function zero_tangent(primal) + has_mutable_tangent(primal) || return ZeroTangent() # note this takes care of tuples + zfield_exprs = map(fieldnames(primal)) do fname + fval = Expr(:call, zero_tangent, Expr(:call, getfield, :primal, QuoteNode(fname))) + Expr(:kw, fname, fval) + end + backing_expr = Expr(:tuple, Expr(:parameters, zfield_exprs...)) + return :($MutableTangent{$primal}($backing_expr)) +end \ No newline at end of file diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index dd7a53ec6..864029572 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -22,7 +22,7 @@ function StructuralTangent{P}(nt::NamedTuple) where {P} end end -has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(T) > 0) +has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(P) > 0) StructuralTangent{P}(tup::Tuple) where P = Tangent{P,typeof(tup)}(tup) diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 028d942ea..be2c74241 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -160,3 +160,17 @@ @test isempty(detect_ambiguities(M)) end end + +@testset "zero_tangent" begin + mutable struct MutDemo + x::Float64 + end + @test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo} + @test iszero(zero_tangent(MutDemo(1.5))) + + @test zero_tangent((;a=1)) isa ZeroTangent + + @test zero_tangent([1.0, 2.0]) == [0.0, 0.0] + @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] +end + From f45fbc76bf19f368caa50d1274da67d19d5487b4 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 28 Sep 2023 17:31:37 +0800 Subject: [PATCH 10/34] export StructuralTangent --- src/ChainRulesCore.jl | 2 +- src/tangent_types/structural_tangent.jl | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 51db59b64..2a2f93c64 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -14,7 +14,7 @@ export ProjectTo, canonicalize, unthunk, zero_tangent # tangent operations export add!!, is_inplaceable_destination # gradient accumulation operations export ignore_derivatives, @ignore_derivatives # tangents -export Tangent, MutableTangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk +export StructuralTangent, Tangent, MutableTangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk include("debug_mode.jl") diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 864029572..25aec08b5 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -10,7 +10,6 @@ as an object with mirroring fields. `MutableTangent` is an experimental feature. Thus use of `StructuralTangent` (rather than `Tangent` directly) is also experimental. While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore. - """ abstract type StructuralTangent{P} <: AbstractTangent end @@ -447,4 +446,4 @@ Base.hash(tangent::MutableTangent, h::UInt64) = hash(backing(tangent), h) function Base.:(==)(t1::MutableTangent{T1}, t2::MutableTangent{T2}) where {T1, T2} typeintersect(T1, T2) == Union{} && return false backing(t1)==backing(t2) -end \ No newline at end of file +end From b2bdb2670c409ea1bb91ef6f29d5641c4b708657 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 2 Oct 2023 12:02:19 +0800 Subject: [PATCH 11/34] Style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/tangent_types/abstract_zero.jl | 3 +-- test/tangent_types/structural_tangent.jl | 20 +++++++++++--------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index be2c74241..8193ac28f 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -168,9 +168,8 @@ end @test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo} @test iszero(zero_tangent(MutDemo(1.5))) - @test zero_tangent((;a=1)) isa ZeroTangent + @test zero_tangent((; a=1)) isa ZeroTangent @test zero_tangent([1.0, 2.0]) == [0.0, 0.0] @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] end - diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index 5736b2f1d..f4f753f47 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -430,7 +430,9 @@ end mutable struct MDemo x::Float64 end - function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setfield!), obj::MDemo, field, x) + function ChainRulesCore.frule( + (_, ȯbj, _, ẋ), ::typeof(setfield!), obj::MDemo, field, x + ) y = setfield!(obj, field, x) ẏ = setproperty!(ȯbj, field, ẋ) return y, ẏ @@ -438,7 +440,7 @@ end @testset "usecase" begin obj = MDemo(99.0) - ∂obj = MutableTangent{MDemo}(;x=1.5) + ∂obj = MutableTangent{MDemo}(; x=1.5) frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0) @test ∂obj.x == 10.0 @test obj.x == 95.0 @@ -450,14 +452,14 @@ end end @testset "== and hash" begin - @test MutableTangent{Any}(x=1.0) == MutableTangent{MDemo}(x=1.0) - @test MutableTangent{MDemo}(x=1.0) == MutableTangent{Any}(x=1.0) - @test MutableTangent{Any}(x=2.0) != MutableTangent{MDemo}(x=1.0) - @test MutableTangent{MDemo}(x=1.0) != MutableTangent{Any}(x=2.0) + @test MutableTangent{Any}(; x=1.0) == MutableTangent{MDemo}(; x=1.0) + @test MutableTangent{MDemo}(; x=1.0) == MutableTangent{Any}(; x=1.0) + @test MutableTangent{Any}(; x=2.0) != MutableTangent{MDemo}(; x=1.0) + @test MutableTangent{MDemo}(; x=1.0) != MutableTangent{Any}(; x=2.0) - nt = (;x=1.0) - @test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(x=1.0) + nt = (; x=1.0) + @test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(; x=1.0) - @test hash(MutableTangent{Any}(x=1.0)) == hash(MutableTangent{MDemo}(x=1.0)) + @test hash(MutableTangent{Any}(; x=1.0)) == hash(MutableTangent{MDemo}(; x=1.0)) end end \ No newline at end of file From 0438217811ed7b0006ce270d5930336a564f7b40 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 4 Oct 2023 15:45:22 +0800 Subject: [PATCH 12/34] handle unassigned a bit more --- src/tangent_types/abstract_zero.jl | 28 +++++++++++++++++++++++----- test/tangent_types/abstract_zero.jl | 25 +++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index dd9068b74..e167d6375 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -107,10 +107,9 @@ For immutable types, this is simply [`ZeroTangent()`](@ref) as accumulation is d Exactly how it should be used (e.g. is it forward-mode only?) """ function zero_tangent end -zero_tangent(::AbstractString) = ZeroTangent() -# zero_tangent(::Number) = zero(x) # TODO: do we want this? -zero_tangent(primal::Array{<:Number}) = zero(primal) # TODO: do we want this? -zero_tangent(primal::Array) = map(zero_tangent, primal) + +zero_tangent(x::Number) = zero(x) + @generated function zero_tangent(primal) has_mutable_tangent(primal) || return ZeroTangent() # note this takes care of tuples zfield_exprs = map(fieldnames(primal)) do fname @@ -119,4 +118,23 @@ zero_tangent(primal::Array) = map(zero_tangent, primal) end backing_expr = Expr(:tuple, Expr(:parameters, zfield_exprs...)) return :($MutableTangent{$primal}($backing_expr)) -end \ No newline at end of file +end + +function zero_tangent(x::Array{P, N}) where {P, N} + (isbitstype(P) || all(i->isassigned(x,i), eachindex(x))) && return map(zero_tangent, x) + + # Now we need to handle nonfully assigned arrays + # see discussion at https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626#discussion_r1345235265 + y = Array{guess_zero_tangent_type(P), N}(undef, size(x)...) + @inbounds for n in eachindex(y) + if isassigned(x, n) + y[n] = zero_tangent(x[n]) + end + end + return y +end + +guess_zero_tangent_type(::Type{T}) where {T<:Number} = T +guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} = Array{guess_zero_tangent_type(T), N} +guess_zero_tangent_type(::Any) = Any # if we had a general way to handle determining tangent type # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/634 + # TODO: we might be able to do better than this. even without. \ No newline at end of file diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 8193ac28f..7060575f1 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -172,4 +172,29 @@ end @test zero_tangent([1.0, 2.0]) == [0.0, 0.0] @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] + + @testset "undef elements" begin + x = Vector{Vector{Float64}}(undef, 3) + x[2] = [1.0,2.0] + dx = zero_tangent(x) + @test dx isa Vector{Vector{Float64}} + @test length(dx) == 3 + @test !isassigned(dx, 1) + @test dx[2] == [0.0, 0.0] + @test !isassigned(dx, 3) + + + a = Vector{MutDemo}(undef, 3) + a[2] = MutDemo(1.5) + da = zero_tangent(a) + @test !isassigned(da, 1) + @test iszero(da[2]) + @test !isassigned(da, 3) + + + db = zero_tangent(Vector{MutDemo}(undef, 3)) + @test all(ii->!isassigned(db,ii), eachindex(db)) + @test length(db)==3 + @test db isa Vector + end end From 4852c917b838da3c120f20b702f9ec3c043911ae Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 4 Oct 2023 15:58:59 +0800 Subject: [PATCH 13/34] add some more test cases to zero_tangent --- test/tangent_types/abstract_zero.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 7060575f1..1f6d857d9 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -165,10 +165,17 @@ end mutable struct MutDemo x::Float64 end + struct Demo + x::Float64 + end @test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo} @test iszero(zero_tangent(MutDemo(1.5))) @test zero_tangent((; a=1)) isa ZeroTangent + @test zero_tangent(Demo(1.2)) isa ZeroTangent + + @test zero_tangent(1) === 0 + @test zero_tangent(1.0) === 0.0 @test zero_tangent([1.0, 2.0]) == [0.0, 0.0] @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] From 0f82019e7fd8dba52df2a099ab34d1aeb31f3946 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 4 Oct 2023 23:32:59 +0800 Subject: [PATCH 14/34] style Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/abstract_zero.jl | 15 +++++++++------ test/tangent_types/abstract_zero.jl | 10 ++++------ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index e167d6375..f79526cf9 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -120,12 +120,13 @@ zero_tangent(x::Number) = zero(x) return :($MutableTangent{$primal}($backing_expr)) end -function zero_tangent(x::Array{P, N}) where {P, N} - (isbitstype(P) || all(i->isassigned(x,i), eachindex(x))) && return map(zero_tangent, x) - +function zero_tangent(x::Array{P,N}) where {P,N} + (isbitstype(P) || all(i -> isassigned(x, i), eachindex(x))) && + return map(zero_tangent, x) + # Now we need to handle nonfully assigned arrays # see discussion at https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626#discussion_r1345235265 - y = Array{guess_zero_tangent_type(P), N}(undef, size(x)...) + y = Array{guess_zero_tangent_type(P),N}(undef, size(x)...) @inbounds for n in eachindex(y) if isassigned(x, n) y[n] = zero_tangent(x[n]) @@ -135,6 +136,8 @@ function zero_tangent(x::Array{P, N}) where {P, N} end guess_zero_tangent_type(::Type{T}) where {T<:Number} = T -guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} = Array{guess_zero_tangent_type(T), N} +function guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} + return Array{guess_zero_tangent_type(T),N} +end guess_zero_tangent_type(::Any) = Any # if we had a general way to handle determining tangent type # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/634 - # TODO: we might be able to do better than this. even without. \ No newline at end of file +# TODO: we might be able to do better than this. even without. \ No newline at end of file diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 1f6d857d9..eb9b757a1 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -182,7 +182,7 @@ end @testset "undef elements" begin x = Vector{Vector{Float64}}(undef, 3) - x[2] = [1.0,2.0] + x[2] = [1.0, 2.0] dx = zero_tangent(x) @test dx isa Vector{Vector{Float64}} @test length(dx) == 3 @@ -190,7 +190,6 @@ end @test dx[2] == [0.0, 0.0] @test !isassigned(dx, 3) - a = Vector{MutDemo}(undef, 3) a[2] = MutDemo(1.5) da = zero_tangent(a) @@ -198,10 +197,9 @@ end @test iszero(da[2]) @test !isassigned(da, 3) - db = zero_tangent(Vector{MutDemo}(undef, 3)) - @test all(ii->!isassigned(db,ii), eachindex(db)) - @test length(db)==3 + @test all(ii -> !isassigned(db, ii), eachindex(db)) + @test length(db) == 3 @test db isa Vector - end + end end From 557469193357eea4948e4284f7fbc448dfcf0c5e Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 6 Oct 2023 23:31:15 +0800 Subject: [PATCH 15/34] Handle Structs with undef fields --- src/tangent_types/abstract_zero.jl | 6 +++++- test/tangent_types/abstract_zero.jl | 18 +++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index f79526cf9..7f3cd0944 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -113,7 +113,11 @@ zero_tangent(x::Number) = zero(x) @generated function zero_tangent(primal) has_mutable_tangent(primal) || return ZeroTangent() # note this takes care of tuples zfield_exprs = map(fieldnames(primal)) do fname - fval = Expr(:call, zero_tangent, Expr(:call, getfield, :primal, QuoteNode(fname))) + fval = if isdefined(primal, fname) + Expr(:call, zero_tangent, Expr(:call, getfield, :primal, QuoteNode(fname))) + else + ZeroTangent() + end Expr(:kw, fname, fval) end backing_expr = Expr(:tuple, Expr(:parameters, zfield_exprs...)) diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index eb9b757a1..1e1b2f28a 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -180,7 +180,7 @@ end @test zero_tangent([1.0, 2.0]) == [0.0, 0.0] @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] - @testset "undef elements" begin + @testset "undef elements Vector" begin x = Vector{Vector{Float64}}(undef, 3) x[2] = [1.0, 2.0] dx = zero_tangent(x) @@ -202,4 +202,20 @@ end @test length(db) == 3 @test db isa Vector end + + @testset "undef fields struct" begin + dx = zero_tangent(Core.Box()) + @test dx.contents isa ZeroTangent + @test (dx.contents = 2.0) == 2.0 # should be assignable + + mutable struct MyPartiallyDefinedStruct + intro::Float64 + contents::Number + MyPartiallyDefinedStruct(x) = new(x) + end + dy = zero_tangent(MyPartiallyDefinedStruct(1.5)) + @test iszero(dy.intro) + @test iszero(dy.contents) + @test (dy.contents = 2.0) == 2.0 # should be assignable + end end From e9cc22194837dc0d69040476d36fa6ee415d3b08 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 22 Dec 2023 19:20:48 +0800 Subject: [PATCH 16/34] overhaul zero_tangent and MutableTangent for type stability --- src/tangent_types/abstract_zero.jl | 45 ++++-- src/tangent_types/structural_tangent.jl | 174 ++++++++++++++---------- test/tangent_types/abstract_zero.jl | 57 ++++++-- 3 files changed, 177 insertions(+), 99 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 7f3cd0944..94a7bc084 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -111,22 +111,40 @@ function zero_tangent end zero_tangent(x::Number) = zero(x) @generated function zero_tangent(primal) - has_mutable_tangent(primal) || return ZeroTangent() # note this takes care of tuples zfield_exprs = map(fieldnames(primal)) do fname - fval = if isdefined(primal, fname) - Expr(:call, zero_tangent, Expr(:call, getfield, :primal, QuoteNode(fname))) - else - ZeroTangent() - end + fval = :( + if isdefined(primal, $(QuoteNode(fname))) + zero_tangent(getfield(primal, $(QuoteNode(fname)))) + else + # This is going to be potentially bad, but that's what they get for not giving us a primal + # This will never me mutated inplace, rather it will alway be replaced with an actual value first + ZeroTangent() + end + ) Expr(:kw, fname, fval) end - backing_expr = Expr(:tuple, Expr(:parameters, zfield_exprs...)) - return :($MutableTangent{$primal}($backing_expr)) + + return if has_mutable_tangent(primal) + any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype + # If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent + fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype)) + Expr(:kw, fname, fdef) + end + :($MutableTangent{$primal}( + $(Expr(:tuple, Expr(:parameters, any_mask...))), + $(Expr(:tuple, Expr(:parameters, zfield_exprs...))) + )) + else + :($Tangent{$primal}($(Expr(:parameters, zfield_exprs...)))) + end end +zero_tangent(primal::Tuple) = Tangent{typeof(primal)}(map(zero_tangent, primal)...) + function zero_tangent(x::Array{P,N}) where {P,N} - (isbitstype(P) || all(i -> isassigned(x, i), eachindex(x))) && + if (isbitstype(P) || all(i -> isassigned(x, i), eachindex(x))) return map(zero_tangent, x) + end # Now we need to handle nonfully assigned arrays # see discussion at https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626#discussion_r1345235265 @@ -139,9 +157,8 @@ function zero_tangent(x::Array{P,N}) where {P,N} return y end +# Sad heauristic methods we need because of unassigned values guess_zero_tangent_type(::Type{T}) where {T<:Number} = T -function guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} - return Array{guess_zero_tangent_type(T),N} -end -guess_zero_tangent_type(::Any) = Any # if we had a general way to handle determining tangent type # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/634 -# TODO: we might be able to do better than this. even without. \ No newline at end of file +guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T))) +guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} = return Array{guess_zero_tangent_type(T),N} +guess_zero_tangent_type(T::Type)= Any \ No newline at end of file diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 25aec08b5..192d58f8c 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -13,6 +13,90 @@ as an object with mirroring fields. """ abstract type StructuralTangent{P} <: AbstractTangent end +""" + Tangent{P, T} <: StructuralTangent{P} <: AbstractTangent + +This type represents the tangent for a `struct`/`NamedTuple`, or `Tuple`. +`P` is the the corresponding primal type that this is a tangent for. + +`Tangent{P}` should have fields (technically properties), that match to a subset of the +fields of the primal type; and each should be a tangent type matching to the primal +type of that field. +Fields of the P that are not present in the Tangent are treated as `Zero`. + +`T` is an implementation detail representing the backing data structure. +For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`. +It should not be passed in by user. + +For `Tangent`s of `Tuple`s, `iterate` and `getindex` are overloaded to behave similarly +to for a tuple. +For `Tangent`s of `struct`s, `getproperty` is overloaded to allow for accessing values +via `tangent.fieldname`. +Any fields not explictly present in the `Tangent` are treated as being set to `ZeroTangent()`. +To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref) +function is provided. +""" +struct Tangent{P,T} <: StructuralTangent{P} + # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict + # (but potentially a different one, as it doesn't contain tangents) + backing::T + + function Tangent{P,T}(backing) where {P,T} + if P <: Tuple + T <: Tuple || _backing_error(P, T, Tuple) + elseif P <: AbstractDict + T <: AbstractDict || _backing_error(P, T, AbstractDict) + elseif P === Any # can be anything + else # Any other struct (including NamedTuple) + T <: NamedTuple || _backing_error(P, T, NamedTuple) + end + return new(backing) + end +end + +""" + MutableTangent{P}(fields) <: StructuralTangent{P} <: AbstractTangent + +This type represents the tangent to a mutable struct. +It itself is also mutable. + +!!! warning Exprimental + MutableTangent is an experimental feature, and is part of the mutation support featureset. + While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore. + Exactly how it should be used (e.g. is it forward-mode only?) + +!!! warning Do not directly mess with the tangent backing data + It is relatively straight forward for a forwards-mode AD to work correctly in the presence of mutation and aliasing of primal values. + However, this requires that the tangent is aliased in turn and conversely that it is copied when the primal is). + If you seperately alias the backing data, etc by using the internal `ChainRulesCore.backing` function you can break this. +""" +struct MutableTangent{P,F} <: StructuralTangent{P} + backing::F + + function MutableTangent{P}(fieldvals) where P + backing = map(Ref, fieldvals) + return new{P, typeof(backing)}(backing) + end + function MutableTangent{P}( + any_mask::NamedTuple{names, <:NTuple{<:Any, Bool}}, fvals::NamedTuple{names} + ) where {names, P} + + backing = map(any_mask, fvals) do isany, fval + ref = if isany + Ref{Any} + else + Ref + end + return ref(fval) + end + return new{P, typeof(backing)}(backing) + end +end + +#################################################################### +# StructuralTangent Common + + function StructuralTangent{P}(nt::NamedTuple) where {P} if has_mutable_tangent(P) return MutableTangent{P}(nt) @@ -21,6 +105,7 @@ function StructuralTangent{P}(nt::NamedTuple) where {P} end end + has_mutable_tangent(::Type{P}) where P = ismutabletype(P) && (!isabstracttype(P) && fieldcount(P) > 0) @@ -40,6 +125,9 @@ end Base.iszero(t::StructuralTangent) = all(iszero, backing(t)) function Base.map(f, tangent::StructuralTangent{P}) where {P} + #TODO: is it even useful to support this on MutableTangents? + #TODO: we implictly assume only linear `f` are called and that it is safe to ignore noncanonical Zeros + # This feels like a fair assumption since all normal operations on tangents are linear L = propertynames(backing(tangent)) vals = map(f, Tuple(backing(tangent))) named_vals = NamedTuple{L,typeof(vals)}(vals) @@ -63,7 +151,8 @@ primal types. backing(x::Tuple) = x backing(x::NamedTuple) = x backing(x::Dict) = x -backing(x::StructuralTangent) = getfield(x, :backing) +backing(x::Tangent) = getfield(x, :backing) +backing(x::MutableTangent) = map(getindex, getfield(x, :backing)) # For generic structs function backing(x::T)::NamedTuple where {T} @@ -206,46 +295,8 @@ function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P} return println(io) end -""" - Tangent{P, T} <: StructuralTangent{P} <: AbstractTangent - -This type represents the tangent for a `struct`/`NamedTuple`, or `Tuple`. -`P` is the the corresponding primal type that this is a tangent for. - -`Tangent{P}` should have fields (technically properties), that match to a subset of the -fields of the primal type; and each should be a tangent type matching to the primal -type of that field. -Fields of the P that are not present in the Tangent are treated as `Zero`. - -`T` is an implementation detail representing the backing data structure. -For Tuple it will be a Tuple, and for everything else it will be a `NamedTuple`. -It should not be passed in by user. - -For `Tangent`s of `Tuple`s, `iterate` and `getindex` are overloaded to behave similarly -to for a tuple. -For `Tangent`s of `struct`s, `getproperty` is overloaded to allow for accessing values -via `tangent.fieldname`. -Any fields not explictly present in the `Tangent` are treated as being set to `ZeroTangent()`. -To make a `Tangent` have all the fields of the primal the [`canonicalize`](@ref) -function is provided. -""" -struct Tangent{P,T} <: StructuralTangent{P} - # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict - # (but potentially a different one, as it doesn't contain tangents) - backing::T - - function Tangent{P,T}(backing) where {P,T} - if P <: Tuple - T <: Tuple || _backing_error(P, T, Tuple) - elseif P <: AbstractDict - T <: AbstractDict || _backing_error(P, T, AbstractDict) - elseif P === Any # can be anything - else # Any other struct (including NamedTuple) - T <: NamedTuple || _backing_error(P, T, NamedTuple) - end - return new(backing) - end -end +####################################### +# immutable Tangent function Tangent{P}(; kwargs...) where {P} backing = (; kwargs...) # construct as NamedTuple @@ -401,46 +452,19 @@ canonicalize(tangent::Tangent{Any,<:NamedTuple{L}}) where {L} = tangent canonicalize(tangent::Tangent{Any,<:Tuple}) = tangent canonicalize(tangent::Tangent{Any,<:AbstractDict}) = tangent - -""" - MutableTangent{P}(fields) <: StructuralTangent{P} <: AbstractTangent - -This type represents the tangent to a mutable struct. -It itself is also mutable. - -!!! warning Exprimental - MutableTangent is an experimental feature, and is part of the mutation support featureset. - While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore. - Exactly how it should be used (e.g. is it forward-mode only?) - -!!! warning Do not directly mess with the tangent backing data - It is relatively straight forward for a forwards-mode AD to work correctly in the presence of mutation and aliasing of primal values. - However, this requires that the tangent is aliased in turn and conversely that it is copied when the primal is). - If you seperately alias the backing data, etc by using the internal `ChainRulesCore.backing` function you can break this. -""" -mutable struct MutableTangent{P} <: StructuralTangent{P} - #TODO: we may want to absolutely lock the type of this down - backing::NamedTuple -end +################################################### +# MutableTangent MutableTangent{P}(;kwargs...) where P = MutableTangent{P}(NamedTuple(kwargs)) -Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(backing(tangent), idx) -Base.getproperty(tangent::MutableTangent, idx::Int) = getfield(backing(tangent), idx) # break ambig +ref_backing(t::MutableTangent) = getfield(t, :backing) -function Base.setproperty!(tangent::MutableTangent, name::Symbol, x) - new_backing = Base.setindex(backing(tangent), x, name) - setfield!(tangent, :backing, new_backing) - return x -end +Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(ref_backing(tangent), idx)[] +Base.getproperty(tangent::MutableTangent, idx::Int) = getfield(ref_backing(tangent), idx)[] # break ambig -function Base.setproperty!(tangent::MutableTangent, idx::Int, x) - # needed due to https://github.com/JuliaLang/julia/issues/43155 - name = idx2sym(backing(tangent), idx) - return setproperty!(tangent, name, x) -end +Base.setproperty!(tangent::MutableTangent, name::Symbol, x) = getproperty(ref_backing(tangent), name)[] = x +Base.setproperty!(tangent::MutableTangent, idx::Int, x) = getproperty(ref_backing(tangent), idx)[] = x # break ambig -idx2sym(::NamedTuple{names}, idx) where names = names[idx] Base.hash(tangent::MutableTangent, h::UInt64) = hash(backing(tangent), h) function Base.:(==)(t1::MutableTangent{T1}, t2::MutableTangent{T2}) where {T1, T2} diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 1e1b2f28a..5e987d613 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -162,6 +162,8 @@ end @testset "zero_tangent" begin + @test zero_tangent(1) === 0 + @test zero_tangent(1.0) === 0.0 mutable struct MutDemo x::Float64 end @@ -171,34 +173,34 @@ end @test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo} @test iszero(zero_tangent(MutDemo(1.5))) - @test zero_tangent((; a=1)) isa ZeroTangent - @test zero_tangent(Demo(1.2)) isa ZeroTangent - - @test zero_tangent(1) === 0 - @test zero_tangent(1.0) === 0.0 + @test zero_tangent((; a=1)) isa Tangent{typeof((;a=1))} + @test zero_tangent(Demo(1.2)) isa Tangent{Demo} + @test zero_tangent(Demo(1.2)).x === 0.0 @test zero_tangent([1.0, 2.0]) == [0.0, 0.0] @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] + @test zero_tangent((1.0, 2.0)) == Tangent{Tuple{Float64,Float64}}(0.0, 0.0) + @testset "undef elements Vector" begin x = Vector{Vector{Float64}}(undef, 3) x[2] = [1.0, 2.0] dx = zero_tangent(x) @test dx isa Vector{Vector{Float64}} @test length(dx) == 3 - @test !isassigned(dx, 1) + @test !isassigned(dx, 1) # We may reconsider this later @test dx[2] == [0.0, 0.0] - @test !isassigned(dx, 3) + @test !isassigned(dx, 3) # We may reconsider this later a = Vector{MutDemo}(undef, 3) a[2] = MutDemo(1.5) da = zero_tangent(a) - @test !isassigned(da, 1) + @test !isassigned(da, 1) # We may reconsider this later @test iszero(da[2]) - @test !isassigned(da, 3) + @test !isassigned(da, 3) # We may reconsider this later db = zero_tangent(Vector{MutDemo}(undef, 3)) - @test all(ii -> !isassigned(db, ii), eachindex(db)) + @test all(ii -> !isassigned(db, ii), eachindex(db)) # We may reconsider this later @test length(db) == 3 @test db isa Vector end @@ -217,5 +219,40 @@ end @test iszero(dy.intro) @test iszero(dy.contents) @test (dy.contents = 2.0) == 2.0 # should be assignable + + mutable struct MyPartiallyDefinedStructWithAnys + intro::Float64 + contents::Any + MyPartiallyDefinedStructWithAnys(x) = new(x) + end + dy = zero_tangent(MyPartiallyDefinedStructWithAnys(1.5)) + @test iszero(dy.intro) + @test iszero(dy.contents) + @test dy.contents === ZeroTangent() # we just don't know anything about this data + @test (dy.contents = 2.0) == 2.0 # should be assignable + @test (dy.contents = [2.0, 4.0]) == [2.0, 4.0] # should be assignable to different values + + mutable struct MyStructWithNonConcreteFields + x::Any + y::Union{Float64, Vector{Float64}} + z::AbstractVector + end + d = zero_tangent(MyStructWithNonConcreteFields(1.0, 2.0, [3.0])) + @test iszero(d.x) + d.x = Tangent{Base.RefValue{Float64}}(x=1.5) + @test d.x == Tangent{Base.RefValue{Float64}}(x=1.5) #should be assignable + d.x=2.4 + @test d.x == 2.4 #should be assignable + @test iszero(d.y) + d.y=2.4 + @test d.y == 2.4 #should be assignable + d.y=[2.4] + @test d.y == [2.4] #should be assignable + @test iszero(d.z) + d.z = [1.0, 2.0] + @test d.z == [1.0, 2.0] + d.z = @view [2.0,3.0,4.0][1:2] + @test d.z == [2.0, 3.0] + @test d.z isa SubArray end end From baea9d3e985396a76983d4cdba9d3a3ff1b176b6 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 27 Dec 2023 11:37:48 +0800 Subject: [PATCH 17/34] set MutableTangent setproperty! on index --- src/tangent_types/structural_tangent.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 192d58f8c..f68eefb76 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -462,8 +462,8 @@ ref_backing(t::MutableTangent) = getfield(t, :backing) Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(ref_backing(tangent), idx)[] Base.getproperty(tangent::MutableTangent, idx::Int) = getfield(ref_backing(tangent), idx)[] # break ambig -Base.setproperty!(tangent::MutableTangent, name::Symbol, x) = getproperty(ref_backing(tangent), name)[] = x -Base.setproperty!(tangent::MutableTangent, idx::Int, x) = getproperty(ref_backing(tangent), idx)[] = x # break ambig +Base.setproperty!(tangent::MutableTangent, name::Symbol, x) = getfield(ref_backing(tangent), name)[] = x +Base.setproperty!(tangent::MutableTangent, idx::Int, x) = getfield(ref_backing(tangent), idx)[] = x # break ambig Base.hash(tangent::MutableTangent, h::UInt64) = hash(backing(tangent), h) From a27f1b60628316c7bb53dde707713572c84423be Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 27 Dec 2023 11:43:13 +0800 Subject: [PATCH 18/34] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/structural_tangent.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index f68eefb76..7262d4da7 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -462,9 +462,12 @@ ref_backing(t::MutableTangent) = getfield(t, :backing) Base.getproperty(tangent::MutableTangent, idx::Symbol) = getfield(ref_backing(tangent), idx)[] Base.getproperty(tangent::MutableTangent, idx::Int) = getfield(ref_backing(tangent), idx)[] # break ambig -Base.setproperty!(tangent::MutableTangent, name::Symbol, x) = getfield(ref_backing(tangent), name)[] = x -Base.setproperty!(tangent::MutableTangent, idx::Int, x) = getfield(ref_backing(tangent), idx)[] = x # break ambig - +function Base.setproperty!(tangent::MutableTangent, name::Symbol, x) + return getfield(ref_backing(tangent), name)[] = x +end +function Base.setproperty!(tangent::MutableTangent, idx::Int, x) + return getfield(ref_backing(tangent), idx)[] = x +end # break ambig Base.hash(tangent::MutableTangent, h::UInt64) = hash(backing(tangent), h) function Base.:(==)(t1::MutableTangent{T1}, t2::MutableTangent{T2}) where {T1, T2} From 4cfce0bc4346648f66933d8104b27b71f2726ed8 Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 28 Dec 2023 16:09:15 +0800 Subject: [PATCH 19/34] handle abstract fields right in mutable tangents outside of zero tangent --- src/tangent_types/structural_tangent.jl | 10 +++--- test/tangent_types/structural_tangent.jl | 41 +++++++++++++++++++++--- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 7262d4da7..c7ae1b1b5 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -73,10 +73,6 @@ It itself is also mutable. struct MutableTangent{P,F} <: StructuralTangent{P} backing::F - function MutableTangent{P}(fieldvals) where P - backing = map(Ref, fieldvals) - return new{P, typeof(backing)}(backing) - end function MutableTangent{P}( any_mask::NamedTuple{names, <:NTuple{<:Any, Bool}}, fvals::NamedTuple{names} ) where {names, P} @@ -91,8 +87,14 @@ struct MutableTangent{P,F} <: StructuralTangent{P} end return new{P, typeof(backing)}(backing) end + + function MutableTangent{P}(fvals) where P + any_mask = NamedTuple{fieldnames(P)}((!isconcretetype).(fieldtypes(P))) + return MutableTangent{P}(any_mask, fvals) + end end + #################################################################### # StructuralTangent Common diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index f4f753f47..8ab5a6bc6 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -4,6 +4,11 @@ struct Foo y::Float64 end +mutable struct MFoo + x::Float64 + y +end + # For testing Primal + Tangent performance struct Bar x::Float64 @@ -452,14 +457,40 @@ end end @testset "== and hash" begin - @test MutableTangent{Any}(; x=1.0) == MutableTangent{MDemo}(; x=1.0) - @test MutableTangent{MDemo}(; x=1.0) == MutableTangent{Any}(; x=1.0) - @test MutableTangent{Any}(; x=2.0) != MutableTangent{MDemo}(; x=1.0) - @test MutableTangent{MDemo}(; x=1.0) != MutableTangent{Any}(; x=2.0) + @test MutableTangent{MDemo}(; x=1f0) == MutableTangent{MDemo}(; x=1.0) + @test MutableTangent{MDemo}(; x=1.0) == MutableTangent{MDemo}(; x=1f0) + @test MutableTangent{MDemo}(; x=2.0) != MutableTangent{MDemo}(; x=1.0) + @test MutableTangent{MDemo}(; x=1.0) != MutableTangent{MDemo}(; x=2.0) nt = (; x=1.0) @test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(; x=1.0) - @test hash(MutableTangent{Any}(; x=1.0)) == hash(MutableTangent{MDemo}(; x=1.0)) + @test hash(MutableTangent{MDemo}(; x=1f0)) == hash(MutableTangent{MDemo}(; x=1.0)) + end + + @testset "Mutation" begin + v = MutableTangent{MFoo}(x=1.5, y=2.4) + v.x = 1.6 + @test v == MutableTangent{MFoo}(x=1.6, y=2.4) + v.y = [1.0, 2.0] # change type, because primal can change type + @test v == MutableTangent{MFoo}(x=1.6, y=[1.0, 2.0]) + end +end + +@testset "map" begin + @testset "Tangent" begin + ∂foo = Tangent{Foo}(x=1.5, y=2.4) + @test map(v->2*v, ∂foo) == Tangent{Foo}(x=3.0, y=4.8) + + ∂foo = Tangent{Foo}(x=1.5) + @test map(v->2*v, ∂foo) == Tangent{Foo}(x=3.0) + end + @testset "MutableTangent" begin + ∂foo = MutableTangent{MFoo}(x=1.5, y=2.4) + ∂foo2 = map(v->2*v, ∂foo) + @test ∂foo2 == MutableTangent{MFoo}(x=3.0, y=4.8) + # Check can still be mutated to new typ + ∂foo2.y=[1.0, 2.0] + @test ∂foo2 == MutableTangent{MFoo}(x=3.0, y=[1.0, 2.0]) end end \ No newline at end of file From ad9a5af6463dee12f782610b337f08f56a1f0e0b Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 28 Dec 2023 16:56:52 +0800 Subject: [PATCH 20/34] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/abstract_zero.jl | 11 +++++---- test/tangent_types/abstract_zero.jl | 17 +++++++------- test/tangent_types/structural_tangent.jl | 30 ++++++++++++------------ 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 94a7bc084..a0f8d1f82 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -123,7 +123,6 @@ zero_tangent(x::Number) = zero(x) ) Expr(:kw, fname, fval) end - return if has_mutable_tangent(primal) any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype # If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent @@ -132,11 +131,11 @@ zero_tangent(x::Number) = zero(x) end :($MutableTangent{$primal}( $(Expr(:tuple, Expr(:parameters, any_mask...))), - $(Expr(:tuple, Expr(:parameters, zfield_exprs...))) + $(Expr(:tuple, Expr(:parameters, zfield_exprs...))), )) else :($Tangent{$primal}($(Expr(:parameters, zfield_exprs...)))) - end + end end zero_tangent(primal::Tuple) = Tangent{typeof(primal)}(map(zero_tangent, primal)...) @@ -160,5 +159,7 @@ end # Sad heauristic methods we need because of unassigned values guess_zero_tangent_type(::Type{T}) where {T<:Number} = T guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T))) -guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} = return Array{guess_zero_tangent_type(T),N} -guess_zero_tangent_type(T::Type)= Any \ No newline at end of file +function guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} + return Array{guess_zero_tangent_type(T),N} +end +guess_zero_tangent_type(T::Type) = Any \ No newline at end of file diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 5e987d613..81114511a 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -173,7 +173,7 @@ end @test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo} @test iszero(zero_tangent(MutDemo(1.5))) - @test zero_tangent((; a=1)) isa Tangent{typeof((;a=1))} + @test zero_tangent((; a=1)) isa Tangent{typeof((; a = 1))} @test zero_tangent(Demo(1.2)) isa Tangent{Demo} @test zero_tangent(Demo(1.2)).x === 0.0 @@ -181,7 +181,6 @@ end @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] @test zero_tangent((1.0, 2.0)) == Tangent{Tuple{Float64,Float64}}(0.0, 0.0) - @testset "undef elements Vector" begin x = Vector{Vector{Float64}}(undef, 3) x[2] = [1.0, 2.0] @@ -234,24 +233,24 @@ end mutable struct MyStructWithNonConcreteFields x::Any - y::Union{Float64, Vector{Float64}} + y::Union{Float64,Vector{Float64}} z::AbstractVector end d = zero_tangent(MyStructWithNonConcreteFields(1.0, 2.0, [3.0])) @test iszero(d.x) - d.x = Tangent{Base.RefValue{Float64}}(x=1.5) - @test d.x == Tangent{Base.RefValue{Float64}}(x=1.5) #should be assignable - d.x=2.4 + d.x = Tangent{Base.RefValue{Float64}}(; x=1.5) + @test d.x == Tangent{Base.RefValue{Float64}}(; x=1.5) #should be assignable + d.x = 2.4 @test d.x == 2.4 #should be assignable @test iszero(d.y) - d.y=2.4 + d.y = 2.4 @test d.y == 2.4 #should be assignable - d.y=[2.4] + d.y = [2.4] @test d.y == [2.4] #should be assignable @test iszero(d.z) d.z = [1.0, 2.0] @test d.z == [1.0, 2.0] - d.z = @view [2.0,3.0,4.0][1:2] + d.z = @view [2.0, 3.0, 4.0][1:2] @test d.z == [2.0, 3.0] @test d.z isa SubArray end diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index 8ab5a6bc6..0982f97c0 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -457,40 +457,40 @@ end end @testset "== and hash" begin - @test MutableTangent{MDemo}(; x=1f0) == MutableTangent{MDemo}(; x=1.0) - @test MutableTangent{MDemo}(; x=1.0) == MutableTangent{MDemo}(; x=1f0) + @test MutableTangent{MDemo}(; x=1.0f0) == MutableTangent{MDemo}(; x=1.0) + @test MutableTangent{MDemo}(; x=1.0) == MutableTangent{MDemo}(; x=1.0f0) @test MutableTangent{MDemo}(; x=2.0) != MutableTangent{MDemo}(; x=1.0) @test MutableTangent{MDemo}(; x=1.0) != MutableTangent{MDemo}(; x=2.0) nt = (; x=1.0) @test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(; x=1.0) - @test hash(MutableTangent{MDemo}(; x=1f0)) == hash(MutableTangent{MDemo}(; x=1.0)) + @test hash(MutableTangent{MDemo}(; x=1.0f0)) == hash(MutableTangent{MDemo}(; x=1.0)) end @testset "Mutation" begin - v = MutableTangent{MFoo}(x=1.5, y=2.4) + v = MutableTangent{MFoo}(; x=1.5, y=2.4) v.x = 1.6 - @test v == MutableTangent{MFoo}(x=1.6, y=2.4) + @test v == MutableTangent{MFoo}(; x=1.6, y=2.4) v.y = [1.0, 2.0] # change type, because primal can change type - @test v == MutableTangent{MFoo}(x=1.6, y=[1.0, 2.0]) + @test v == MutableTangent{MFoo}(; x=1.6, y=[1.0, 2.0]) end end @testset "map" begin @testset "Tangent" begin - ∂foo = Tangent{Foo}(x=1.5, y=2.4) - @test map(v->2*v, ∂foo) == Tangent{Foo}(x=3.0, y=4.8) + ∂foo = Tangent{Foo}(; x=1.5, y=2.4) + @test map(v -> 2 * v, ∂foo) == Tangent{Foo}(; x=3.0, y=4.8) - ∂foo = Tangent{Foo}(x=1.5) - @test map(v->2*v, ∂foo) == Tangent{Foo}(x=3.0) + ∂foo = Tangent{Foo}(; x=1.5) + @test map(v -> 2 * v, ∂foo) == Tangent{Foo}(; x=3.0) end @testset "MutableTangent" begin - ∂foo = MutableTangent{MFoo}(x=1.5, y=2.4) - ∂foo2 = map(v->2*v, ∂foo) - @test ∂foo2 == MutableTangent{MFoo}(x=3.0, y=4.8) + ∂foo = MutableTangent{MFoo}(; x=1.5, y=2.4) + ∂foo2 = map(v -> 2 * v, ∂foo) + @test ∂foo2 == MutableTangent{MFoo}(; x=3.0, y=4.8) # Check can still be mutated to new typ - ∂foo2.y=[1.0, 2.0] - @test ∂foo2 == MutableTangent{MFoo}(x=3.0, y=[1.0, 2.0]) + ∂foo2.y = [1.0, 2.0] + @test ∂foo2 == MutableTangent{MFoo}(; x=3.0, y=[1.0, 2.0]) end end \ No newline at end of file From 8b3d5251c936fdf34092d52ff50a57e49878172d Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 28 Dec 2023 17:55:36 +0800 Subject: [PATCH 21/34] Add docs for forward mutation support --- docs/make.jl | 1 + docs/src/api.md | 2 +- .../superpowers/mutation_support.md | 73 +++++++++++++++++++ 3 files changed, 75 insertions(+), 1 deletion(-) create mode 100644 docs/src/rule_author/superpowers/mutation_support.md diff --git a/docs/make.jl b/docs/make.jl index ad86a84ae..1666665fe 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -61,6 +61,7 @@ makedocs(; "`@opt_out`" => "rule_author/superpowers/opt_out.md", "`RuleConfig`" => "rule_author/superpowers/ruleconfig.md", "Gradient accumulation" => "rule_author/superpowers/gradient_accumulation.md", + "Mutation Support (experimental)" => "rule_author/superpowers/mutation_support.md", ], "Converting ZygoteRules.@adjoint to rrules" => "rule_author/converting_zygoterules.md", "Tips for making your package work with AD" => "rule_author/tips_for_packages.md", diff --git a/docs/src/api.md b/docs/src/api.md index 5648058e0..57b7bf2ad 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -20,7 +20,7 @@ Modules = [ChainRulesCore] Pages = [ "tangent_types/abstract_zero.jl", "tangent_types/one.jl", - "tangent_types/tangent.jl", + "tangent_types/structural_tangent.jl", "tangent_types/thunks.jl", "tangent_types/abstract_tangent.jl", "tangent_types/notimplemented.jl", diff --git a/docs/src/rule_author/superpowers/mutation_support.md b/docs/src/rule_author/superpowers/mutation_support.md new file mode 100644 index 000000000..55629166a --- /dev/null +++ b/docs/src/rule_author/superpowers/mutation_support.md @@ -0,0 +1,73 @@ +# Mutation Support + +ChainRulesCore.jl offers experimental support for mutation, targetting use in forward mode AD. +(Mutation support in reverse mode AD is more complicated and will likely require more changes to the interface) + +!!! warning "Experimental" + This page documents an experimental feature. + Expect breaking changes in minor versions while this remains. + It is not suitable for general use unless you are prepared to modify how you are using it each minor release. + It is thus suggested that if you are using it to use _tilde_ bounds on supported minor versions. + + +## `MutableTangent` +The [`MutableTangent`](@ref) type is designed to be a partner to the [`Tangent`](@ref) type, with specific support for being mutated in place. +It is required to be a structural tangent, having one tangent for each field of the primal object. + +Technically, not all `mutable struct`s need to use `MutableTangent` to represent their tangents. +Just like not all `struct`s need to use `Tangent`s. +Common examples away from this are natural tangent types like for arrays. +However, if one is setting up to use a custom tangent type for this it is surficiently off the beated path that we can not provide much guidance. + +## `zero_tangent` + +The [`zero_tangent`](@ref) function functions to give you a zero (i.e. additive identity) for any primal value. +The [`ZeroTangent`](@ref) type also does this. +The difference is that [`zero_tangent`](@ref) is (where possible) a full structural tangent mirroring the structure of the primal. +For mutation support this is important, since it means that there is mutable memory available in the tangent to be mutated when the primal changes. +To support this you thus need to make sure your zeros are created in various places with [`zero_tangent`](@ref) rather than []`ZeroTangent`](@ref). + +It is also useful for reasons of type stability, since it is always a structural tangent. +For this reason AD system implementors might chose to use this to create the tangent for all literal values they encounter, mutable or not. + +## Writing a frule for a mutating function +It is relatively straight forward to write a frule for a mutating function. +There are a few key points to follow: + - There must be a mutable tangent input for every mutated primal input + - When the primal value is changed, the corresponding change must be made to its tangent partner + - When a value is returned, return its partnered tangent. + + +### Example +For example, consider the primal function with: +1. takes two `Ref`s +2. doubles the first one inplace +3. overwrites the second one's value with the literal 5.0 +4. returns the first one + + +```julia +function foo!(a::Base.RefValue, b::Base.RefValue) + a[] *= 2 + b[] = 5.0 + return a +end +``` + +The frule for this would be: +```julia +function ChainRulesCore.frule((ȧ, ḃ), ::typeof(foo!), a::Base.RefValue, b::Base.RefValue) + @assert ȧ isa MutableTangent{typeof(a)} + @assert ḃ isa MutableTangent{typeof(b)} + + a[] *= 2 + ȧ.x *= 2 # `.x` is the field that lives behind RefValues + + b[]=5.0 + ḃ.x = zero_tangent(5.0) # or since we know that the zero for a Float64 is zero could write `ḃ.x = 0.0` + + return a, ȧ +end +``` + +Then assuming the AD system does its part to makes sure you are indeed given mutable values to mutate (i.e. those `@assert`ions are true) then all is well and this rule will make mutation correct. \ No newline at end of file From c09ff910e268b3c63b20952e2a5065530fc1d5ff Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 29 Dec 2023 12:40:07 +0800 Subject: [PATCH 22/34] use ismutabletype from Compat --- Project.toml | 14 +++++++------- src/ChainRulesCore.jl | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 195cccac5..63fd50344 100644 --- a/Project.toml +++ b/Project.toml @@ -7,20 +7,17 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -[weakdeps] -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[extensions] -ChainRulesCoreSparseArraysExt = "SparseArrays" - [compat] BenchmarkTools = "0.5" -Compat = "2, 3, 4" +Compat = "3.40, 4" FiniteDifferences = "0.10" OffsetArrays = "1" StaticArrays = "0.11, 0.12, 1" julia = "1.6" +[extensions] +ChainRulesCoreSparseArraysExt = "SparseArrays" + [extras] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" @@ -31,3 +28,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test", "BenchmarkTools", "FiniteDifferences", "OffsetArrays", "SparseArrays", "StaticArrays"] + +[weakdeps] +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 2a2f93c64..bda392497 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -2,7 +2,7 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! using Base.Meta using LinearAlgebra -using Compat: hasfield, hasproperty +using Compat: hasfield, hasproperty, ismutabletype export frule, rrule # core function # rule configurations From 59fc470486e47b464de0bc65cb3d2f33ca5a782f Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 29 Dec 2023 12:41:11 +0800 Subject: [PATCH 23/34] wrap structural tangent tests in a common testset --- test/tangent_types/structural_tangent.jl | 799 ++++++++++++----------- 1 file changed, 400 insertions(+), 399 deletions(-) diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index 0982f97c0..16d702c14 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -22,475 +22,476 @@ struct StructWithInvariant StructWithInvariant(x) = new(x, 2x) end +@testset "StructuralTangent" begin + @testset "Tangent" begin + @testset "empty types" begin + @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{},Tuple{}} + end -@testset "Tangent" begin - @testset "empty types" begin - @test typeof(Tangent{Tuple{}}()) == Tangent{Tuple{},Tuple{}} - end - - @testset "constructor" begin - t = (1.0, 2.0) - nt = (x=1, y=2.0) - d = Dict(:x => 1.0, :y => 2.0) - vals = [1, 2] - - @test_throws ArgumentError Tangent{typeof(t),typeof(nt)}(nt) - @test_throws ArgumentError Tangent{typeof(t),typeof(d)}(d) - - @test_throws ArgumentError Tangent{typeof(d),typeof(nt)}(nt) - @test_throws ArgumentError Tangent{typeof(d),typeof(t)}(t) + @testset "constructor" begin + t = (1.0, 2.0) + nt = (x=1, y=2.0) + d = Dict(:x => 1.0, :y => 2.0) + vals = [1, 2] - @test_throws ArgumentError Tangent{typeof(nt),typeof(vals)}(vals) - @test_throws ArgumentError Tangent{typeof(nt),typeof(d)}(d) - @test_throws ArgumentError Tangent{typeof(nt),typeof(t)}(t) + @test_throws ArgumentError Tangent{typeof(t),typeof(nt)}(nt) + @test_throws ArgumentError Tangent{typeof(t),typeof(d)}(d) - @test_throws ArgumentError Tangent{Foo,typeof(d)}(d) - @test_throws ArgumentError Tangent{Foo,typeof(t)}(t) - end + @test_throws ArgumentError Tangent{typeof(d),typeof(nt)}(nt) + @test_throws ArgumentError Tangent{typeof(d),typeof(t)}(t) - @testset "==" begin - @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; x=0.1, y=2.5) - @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; y=2.5, x=0.1) - @test Tangent{Foo}(; y=2.5, x=ZeroTangent()) == Tangent{Foo}(; y=2.5) + @test_throws ArgumentError Tangent{typeof(nt),typeof(vals)}(vals) + @test_throws ArgumentError Tangent{typeof(nt),typeof(d)}(d) + @test_throws ArgumentError Tangent{typeof(nt),typeof(t)}(t) - @test Tangent{Tuple{Float64}}(2.0) == Tangent{Tuple{Float64}}(2.0) - @test Tangent{Dict}(Dict(4 => 3)) == Tangent{Dict}(Dict(4 => 3)) + @test_throws ArgumentError Tangent{Foo,typeof(d)}(d) + @test_throws ArgumentError Tangent{Foo,typeof(t)}(t) + end - tup = (1.0, 2.0) - @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2 * 1.0)) - @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, 2) + @testset "==" begin + @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; x=0.1, y=2.5) + @test Tangent{Foo}(; x=0.1, y=2.5) == Tangent{Foo}(; y=2.5, x=0.1) + @test Tangent{Foo}(; y=2.5, x=ZeroTangent()) == Tangent{Foo}(; y=2.5) - @test Tangent{Foo}(; y=2.0) == Tangent{Foo}(; x=ZeroTangent(), y=Float32(2.0)) - end + @test Tangent{Tuple{Float64}}(2.0) == Tangent{Tuple{Float64}}(2.0) + @test Tangent{Dict}(Dict(4 => 3)) == Tangent{Dict}(Dict(4 => 3)) - @testset "hash" begin - @test hash(Tangent{Foo}(; x=0.1, y=2.5)) == hash(Tangent{Foo}(; y=2.5, x=0.1)) - @test hash(Tangent{Foo}(; y=2.5, x=ZeroTangent())) == hash(Tangent{Foo}(; y=2.5)) - end + tup = (1.0, 2.0) + @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, @thunk(2 * 1.0)) + @test Tangent{typeof(tup)}(1.0, 2.0) == Tangent{typeof(tup)}(1.0, 2) - @testset "indexing, iterating, and properties" begin - @test keys(Tangent{Foo}(; x=2.5)) == (:x,) - @test propertynames(Tangent{Foo}(; x=2.5)) == (:x,) - @test haskey(Tangent{Foo}(; x=2.5), :x) == true - if isdefined(Base, :hasproperty) - @test hasproperty(Tangent{Foo}(; x=2.5), :y) == false + @test Tangent{Foo}(; y=2.0) == Tangent{Foo}(; x=ZeroTangent(), y=Float32(2.0)) end - @test Tangent{Foo}(; x=2.5).x == 2.5 - - tang1 = Tangent{Tuple{Float64}}(2.0) - @test keys(tang1) == Base.OneTo(1) - @test propertynames(Tangent{Tuple{Float64}}(2.0)) == (1,) - @test getindex(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 - @test getindex(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 - @test getproperty(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 - @test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 - @test NoTangent() === @inferred Base.tail(tang1) - @test NoTangent() === @inferred Base.tail(Tangent{Tuple{}}()) - - tang3 = Tangent{Tuple{Float64, String, Vector{Float64}}}(1.0, NoTangent(), @thunk [3.0] .+ 4) - @test @inferred(first(tang3)) === tang3[1] === 1.0 - @test @inferred(last(tang3)) isa Thunk - @test unthunk(last(tang3)) == [7.0] - @test Tuple(@inferred Base.tail(tang3))[1] === NoTangent() - @test Tuple(Base.tail(tang3))[end] isa Thunk - - NT = NamedTuple{(:a, :b),Tuple{Float64,Float64}} - @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 - @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() - @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() - @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 - - @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 - @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() - @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() - @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 - - @test first(Tangent{NT}(; a=(@thunk 2.0^2))) isa Thunk - @test unthunk(first(Tangent{NT}(; a=(@thunk 2.0^2)))) == 4.0 - @test last(Tangent{NT}(; a=(@thunk 2.0^2))) isa ZeroTangent - - ntang1 = @inferred Base.tail(Tangent{NT}(; b=(@thunk 2.0^2))) - @test ntang1 isa Tangent{<:NamedTuple{(:b,)}} - @test NoTangent() === @inferred Base.tail(ntang1) - - # TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516 - # if VERSION >= v"1.8-" - # @test haskey(Tangent{Tuple{Float64}}(2.0), 1) == true - # else - # @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true - # end - @test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false - - @test length(Tangent{Foo}(; x=2.5)) == 1 - @test length(Tangent{Tuple{Float64}}(2.0)) == 1 - - @test eltype(Tangent{Foo}(; x=2.5)) == Float64 - @test eltype(Tangent{Tuple{Float64}}(2.0)) == Float64 - - # Testing iterate via collect - @test collect(Tangent{Foo}(; x=2.5)) == [2.5] - @test collect(Tangent{Tuple{Float64}}(2.0)) == [2.0] - - # Test indexed_iterate - ctup = Tangent{Tuple{Float64,Int64}}(2.0, 3) - _unpack2tuple = function (tangent) - a, b = tangent - return (a, b) + + @testset "hash" begin + @test hash(Tangent{Foo}(; x=0.1, y=2.5)) == hash(Tangent{Foo}(; y=2.5, x=0.1)) + @test hash(Tangent{Foo}(; y=2.5, x=ZeroTangent())) == hash(Tangent{Foo}(; y=2.5)) end - @inferred _unpack2tuple(ctup) - @test _unpack2tuple(ctup) === (2.0, 3) - - # Test getproperty is inferrable - _unpacknamedtuple = tangent -> (tangent.x, tangent.y) - if VERSION ≥ v"1.2" - @inferred _unpacknamedtuple(Tangent{Foo}(; x=2, y=3.0)) - @inferred _unpacknamedtuple(Tangent{Foo}(; y=3.0)) + + @testset "indexing, iterating, and properties" begin + @test keys(Tangent{Foo}(; x=2.5)) == (:x,) + @test propertynames(Tangent{Foo}(; x=2.5)) == (:x,) + @test haskey(Tangent{Foo}(; x=2.5), :x) == true + if isdefined(Base, :hasproperty) + @test hasproperty(Tangent{Foo}(; x=2.5), :y) == false + end + @test Tangent{Foo}(; x=2.5).x == 2.5 + + tang1 = Tangent{Tuple{Float64}}(2.0) + @test keys(tang1) == Base.OneTo(1) + @test propertynames(Tangent{Tuple{Float64}}(2.0)) == (1,) + @test getindex(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 + @test getindex(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 + @test getproperty(Tangent{Tuple{Float64}}(2.0), 1) == 2.0 + @test getproperty(Tangent{Tuple{Float64}}(@thunk 2.0^2), 1) == 4.0 + @test NoTangent() === @inferred Base.tail(tang1) + @test NoTangent() === @inferred Base.tail(Tangent{Tuple{}}()) + + tang3 = Tangent{Tuple{Float64, String, Vector{Float64}}}(1.0, NoTangent(), @thunk [3.0] .+ 4) + @test @inferred(first(tang3)) === tang3[1] === 1.0 + @test @inferred(last(tang3)) isa Thunk + @test unthunk(last(tang3)) == [7.0] + @test Tuple(@inferred Base.tail(tang3))[1] === NoTangent() + @test Tuple(Base.tail(tang3))[end] isa Thunk + + NT = NamedTuple{(:a, :b),Tuple{Float64,Float64}} + @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 + @test getindex(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() + @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() + @test getindex(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 + + @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :a) == 4.0 + @test getproperty(Tangent{NT}(; a=(@thunk 2.0^2)), :b) == ZeroTangent() + @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 1) == ZeroTangent() + @test getproperty(Tangent{NT}(; b=(@thunk 2.0^2)), 2) == 4.0 + + @test first(Tangent{NT}(; a=(@thunk 2.0^2))) isa Thunk + @test unthunk(first(Tangent{NT}(; a=(@thunk 2.0^2)))) == 4.0 + @test last(Tangent{NT}(; a=(@thunk 2.0^2))) isa ZeroTangent + + ntang1 = @inferred Base.tail(Tangent{NT}(; b=(@thunk 2.0^2))) + @test ntang1 isa Tangent{<:NamedTuple{(:b,)}} + @test NoTangent() === @inferred Base.tail(ntang1) + + # TODO: uncomment this once https://github.com/JuliaLang/julia/issues/35516 + # if VERSION >= v"1.8-" + # @test haskey(Tangent{Tuple{Float64}}(2.0), 1) == true + # else + # @test_broken haskey(Tangent{Tuple{Float64}}(2.0), 1) == true + # end + @test_broken hasproperty(Tangent{Tuple{Float64}}(2.0), 2) == false + + @test length(Tangent{Foo}(; x=2.5)) == 1 + @test length(Tangent{Tuple{Float64}}(2.0)) == 1 + + @test eltype(Tangent{Foo}(; x=2.5)) == Float64 + @test eltype(Tangent{Tuple{Float64}}(2.0)) == Float64 + + # Testing iterate via collect + @test collect(Tangent{Foo}(; x=2.5)) == [2.5] + @test collect(Tangent{Tuple{Float64}}(2.0)) == [2.0] + + # Test indexed_iterate + ctup = Tangent{Tuple{Float64,Int64}}(2.0, 3) + _unpack2tuple = function (tangent) + a, b = tangent + return (a, b) + end + @inferred _unpack2tuple(ctup) + @test _unpack2tuple(ctup) === (2.0, 3) + + # Test getproperty is inferrable + _unpacknamedtuple = tangent -> (tangent.x, tangent.y) + if VERSION ≥ v"1.2" + @inferred _unpacknamedtuple(Tangent{Foo}(; x=2, y=3.0)) + @inferred _unpacknamedtuple(Tangent{Foo}(; y=3.0)) + end end - end - @testset "reverse" begin - c = Tangent{Tuple{Int,Int,String}}(1, 2, "something") - cr = Tangent{Tuple{String,Int,Int}}("something", 2, 1) - @test reverse(c) === cr + @testset "reverse" begin + c = Tangent{Tuple{Int,Int,String}}(1, 2, "something") + cr = Tangent{Tuple{String,Int,Int}}("something", 2, 1) + @test reverse(c) === cr - if VERSION < v"1.9-" - # can't reverse a named tuple or a dict - @test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0)) + if VERSION < v"1.9-" + # can't reverse a named tuple or a dict + @test_throws MethodError reverse(Tangent{Foo}(; x=1.0, y=2.0)) - d = Dict(:x => 1, :y => 2.0) - cdict = Tangent{typeof(d),typeof(d)}(d) - @test_throws MethodError reverse(Tangent{Foo}()) - else - # These now work but do we care? + d = Dict(:x => 1, :y => 2.0) + cdict = Tangent{typeof(d),typeof(d)}(d) + @test_throws MethodError reverse(Tangent{Foo}()) + else + # These now work but do we care? + end end - end - @testset "unset properties" begin - @test Tangent{Foo}(; x=1.4).y === ZeroTangent() - end + @testset "unset properties" begin + @test Tangent{Foo}(; x=1.4).y === ZeroTangent() + end - @testset "conj" begin - @test conj(Tangent{Foo}(; x=2.0 + 3.0im)) == Tangent{Foo}(; x=2.0 - 3.0im) - @test ==( - conj(Tangent{Tuple{Float64}}(2.0 + 3.0im)), Tangent{Tuple{Float64}}(2.0 - 3.0im) - ) - @test ==( - conj(Tangent{Dict}(Dict(4 => 2.0 + 3.0im))), - Tangent{Dict}(Dict(4 => 2.0 + -3.0im)), - ) - end + @testset "conj" begin + @test conj(Tangent{Foo}(; x=2.0 + 3.0im)) == Tangent{Foo}(; x=2.0 - 3.0im) + @test ==( + conj(Tangent{Tuple{Float64}}(2.0 + 3.0im)), Tangent{Tuple{Float64}}(2.0 - 3.0im) + ) + @test ==( + conj(Tangent{Dict}(Dict(4 => 2.0 + 3.0im))), + Tangent{Dict}(Dict(4 => 2.0 + -3.0im)), + ) + end - @testset "canonicalize" begin - # Testing iterate via collect - @test ==(canonicalize(Tangent{Tuple{Float64}}(2.0)), Tangent{Tuple{Float64}}(2.0)) + @testset "canonicalize" begin + # Testing iterate via collect + @test ==(canonicalize(Tangent{Tuple{Float64}}(2.0)), Tangent{Tuple{Float64}}(2.0)) - @test ==(canonicalize(Tangent{Dict}(Dict(4 => 3))), Tangent{Dict}(Dict(4 => 3))) + @test ==(canonicalize(Tangent{Dict}(Dict(4 => 3))), Tangent{Dict}(Dict(4 => 3))) - # For structure it needs to match order and ZeroTangent() fill to match primal - CFoo = Tangent{Foo} - @test canonicalize(CFoo(; x=2.5, y=10)) == CFoo(; x=2.5, y=10) - @test canonicalize(CFoo(; y=10, x=2.5)) == CFoo(; x=2.5, y=10) - @test canonicalize(CFoo(; y=10)) == CFoo(; x=ZeroTangent(), y=10) + # For structure it needs to match order and ZeroTangent() fill to match primal + CFoo = Tangent{Foo} + @test canonicalize(CFoo(; x=2.5, y=10)) == CFoo(; x=2.5, y=10) + @test canonicalize(CFoo(; y=10, x=2.5)) == CFoo(; x=2.5, y=10) + @test canonicalize(CFoo(; y=10)) == CFoo(; x=ZeroTangent(), y=10) - @test_throws ArgumentError canonicalize(CFoo(; q=99.0, x=2.5)) + @test_throws ArgumentError canonicalize(CFoo(; q=99.0, x=2.5)) - @testset "unspecified primal type" begin - c1 = Tangent{Any}(; a=1, b=2) - c2 = Tangent{Any}(1, 2) - c3 = Tangent{Any}(Dict(4 => 3)) + @testset "unspecified primal type" begin + c1 = Tangent{Any}(; a=1, b=2) + c2 = Tangent{Any}(1, 2) + c3 = Tangent{Any}(Dict(4 => 3)) - @test c1 == canonicalize(c1) - @test c2 == canonicalize(c2) - @test c3 == canonicalize(c3) + @test c1 == canonicalize(c1) + @test c2 == canonicalize(c2) + @test c3 == canonicalize(c3) + end end - end - @testset "+ with other composites" begin - @testset "Structs" begin - CFoo = Tangent{Foo} - @test CFoo(; x=1.5) + CFoo(; x=2.5) == CFoo(; x=4.0) - @test CFoo(; y=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=2.5) - @test CFoo(; y=1.5, x=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=4.0) - end + @testset "+ with other composites" begin + @testset "Structs" begin + CFoo = Tangent{Foo} + @test CFoo(; x=1.5) + CFoo(; x=2.5) == CFoo(; x=4.0) + @test CFoo(; y=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=2.5) + @test CFoo(; y=1.5, x=1.5) + CFoo(; x=2.5) == CFoo(; y=1.5, x=4.0) + end - @testset "Tuples" begin - @test ==( - typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), Tangent{Tuple{},Tuple{}} - ) - @test ( - Tangent{Tuple{Float64,Float64}}(1.0, 2.0) + - Tangent{Tuple{Float64,Float64}}(1.0, 1.0) - ) == Tangent{Tuple{Float64,Float64}}(2.0, 3.0) - end + @testset "Tuples" begin + @test ==( + typeof(Tangent{Tuple{}}() + Tangent{Tuple{}}()), Tangent{Tuple{},Tuple{}} + ) + @test ( + Tangent{Tuple{Float64,Float64}}(1.0, 2.0) + + Tangent{Tuple{Float64,Float64}}(1.0, 1.0) + ) == Tangent{Tuple{Float64,Float64}}(2.0, 3.0) + end - @testset "NamedTuples" begin - make_tangent(nt::NamedTuple) = Tangent{typeof(nt)}(; nt...) - t1 = make_tangent((; a=1.5, b=0.0)) - t2 = make_tangent((; a=0.0, b=2.5)) - t_sum = make_tangent((a=1.5, b=2.5)) - @test t1 + t2 == t_sum - end + @testset "NamedTuples" begin + make_tangent(nt::NamedTuple) = Tangent{typeof(nt)}(; nt...) + t1 = make_tangent((; a=1.5, b=0.0)) + t2 = make_tangent((; a=0.0, b=2.5)) + t_sum = make_tangent((a=1.5, b=2.5)) + @test t1 + t2 == t_sum + end - @testset "Dicts" begin - d1 = Tangent{Dict}(Dict(4 => 3.0, 3 => 2.0)) - d2 = Tangent{Dict}(Dict(4 => 3.0, 2 => 2.0)) - d_sum = Tangent{Dict}(Dict(4 => 3.0 + 3.0, 3 => 2.0, 2 => 2.0)) - @test d1 + d2 == d_sum + @testset "Dicts" begin + d1 = Tangent{Dict}(Dict(4 => 3.0, 3 => 2.0)) + d2 = Tangent{Dict}(Dict(4 => 3.0, 2 => 2.0)) + d_sum = Tangent{Dict}(Dict(4 => 3.0 + 3.0, 3 => 2.0, 2 => 2.0)) + @test d1 + d2 == d_sum + end + + @testset "Fields of type NotImplemented" begin + CFoo = Tangent{Foo} + a = CFoo(; x=1.5) + b = CFoo(; x=@not_implemented("")) + for (x, y) in ((a, b), (b, a), (b, b)) + z = x + y + @test z isa CFoo + @test z.x isa ChainRulesCore.NotImplemented + end + + a = Tangent{Tuple}(1.5) + b = Tangent{Tuple}(@not_implemented("")) + for (x, y) in ((a, b), (b, a), (b, b)) + z = x + y + @test z isa Tangent{Tuple} + @test first(z) isa ChainRulesCore.NotImplemented + end + + a = Tangent{NamedTuple{(:x,)}}(; x=1.5) + b = Tangent{NamedTuple{(:x,)}}(; x=@not_implemented("")) + for (x, y) in ((a, b), (b, a), (b, b)) + z = x + y + @test z isa Tangent{NamedTuple{(:x,)}} + @test z.x isa ChainRulesCore.NotImplemented + end + + a = Tangent{Dict}(Dict(:x => 1.5)) + b = Tangent{Dict}(Dict(:x => @not_implemented(""))) + for (x, y) in ((a, b), (b, a), (b, b)) + z = x + y + @test z isa Tangent{Dict} + @test z[:x] isa ChainRulesCore.NotImplemented + end + end end - @testset "Fields of type NotImplemented" begin - CFoo = Tangent{Foo} - a = CFoo(; x=1.5) - b = CFoo(; x=@not_implemented("")) - for (x, y) in ((a, b), (b, a), (b, b)) - z = x + y - @test z isa CFoo - @test z.x isa ChainRulesCore.NotImplemented + @testset "+ with Primals" begin + @testset "Structs" begin + @test Foo(3.5, 1.5) + Tangent{Foo}(; x=2.5) == Foo(6.0, 1.5) + @test Tangent{Foo}(; x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) + @test (@ballocated Bar(0.5) + Tangent{Bar}(; x=0.5)) == 0 end - a = Tangent{Tuple}(1.5) - b = Tangent{Tuple}(@not_implemented("")) - for (x, y) in ((a, b), (b, a), (b, b)) - z = x + y - @test z isa Tangent{Tuple} - @test first(z) isa ChainRulesCore.NotImplemented + @testset "Tuples" begin + @test Tangent{Tuple{}}() + () == () + @test ((1.0, 2.0) + Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) == (2.0, 3.0) + @test (Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) end - a = Tangent{NamedTuple{(:x,)}}(; x=1.5) - b = Tangent{NamedTuple{(:x,)}}(; x=@not_implemented("")) - for (x, y) in ((a, b), (b, a), (b, b)) - z = x + y - @test z isa Tangent{NamedTuple{(:x,)}} - @test z.x isa ChainRulesCore.NotImplemented + @testset "NamedTuple" begin + ntx = (; a=1.5) + @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a=3.0) + + nty = (; a=1.5, b=0.5) + @test Tangent{typeof(nty)}(; nty...) + nty == (; a=3.0, b=1.0) end - a = Tangent{Dict}(Dict(:x => 1.5)) - b = Tangent{Dict}(Dict(:x => @not_implemented(""))) - for (x, y) in ((a, b), (b, a), (b, b)) - z = x + y - @test z isa Tangent{Dict} - @test z[:x] isa ChainRulesCore.NotImplemented + @testset "Dicts" begin + d_primal = Dict(4 => 3.0, 3 => 2.0) + d_tangent = Tangent{typeof(d_primal)}(Dict(4 => 5.0)) + @test d_primal + d_tangent == Dict(4 => 3.0 + 5.0, 3 => 2.0) end end - end - @testset "+ with Primals" begin - @testset "Structs" begin - @test Foo(3.5, 1.5) + Tangent{Foo}(; x=2.5) == Foo(6.0, 1.5) - @test Tangent{Foo}(; x=2.5) + Foo(3.5, 1.5) == Foo(6.0, 1.5) - @test (@ballocated Bar(0.5) + Tangent{Bar}(; x=0.5)) == 0 - end + @testset "+ with Primals, with inner constructor" begin + value = StructWithInvariant(10.0) + diff = Tangent{StructWithInvariant}(; x=2.0, x2=6.0) - @testset "Tuples" begin - @test Tangent{Tuple{}}() + () == () - @test ((1.0, 2.0) + Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) == (2.0, 3.0) - @test (Tangent{Tuple{Float64,Float64}}(1.0, 1.0)) + (1.0, 2.0) == (2.0, 3.0) - end + @testset "with and without debug mode" begin + @assert ChainRulesCore.debug_mode() == false + @test_throws MethodError (value + diff) + @test_throws MethodError (diff + value) - @testset "NamedTuple" begin - ntx = (; a=1.5) - @test Tangent{typeof(ntx)}(; ntx...) + ntx == (; a=3.0) + ChainRulesCore.debug_mode() = true # enable debug mode + @test_throws ChainRulesCore.PrimalAdditionFailedException (value + diff) + @test_throws ChainRulesCore.PrimalAdditionFailedException (diff + value) + ChainRulesCore.debug_mode() = false # disable it again + end - nty = (; a=1.5, b=0.5) - @test Tangent{typeof(nty)}(; nty...) + nty == (; a=3.0, b=1.0) + # Now we define constuction for ChainRulesCore.jl's purposes: + # It is going to determine the root quanity of the invarient + function ChainRulesCore.construct(::Type{StructWithInvariant}, nt::NamedTuple) + x = (nt.x + nt.x2 / 2) / 2 + return StructWithInvariant(x) + end + @test value + diff == StructWithInvariant(12.5) + @test diff + value == StructWithInvariant(12.5) end - @testset "Dicts" begin - d_primal = Dict(4 => 3.0, 3 => 2.0) - d_tangent = Tangent{typeof(d_primal)}(Dict(4 => 5.0)) - @test d_primal + d_tangent == Dict(4 => 3.0 + 5.0, 3 => 2.0) + @testset "differential arithmetic" begin + c = Tangent{Foo}(; y=1.5, x=2.5) + + @test NoTangent() * c == NoTangent() + @test c * NoTangent() == NoTangent() + @test dot(NoTangent(), c) == NoTangent() + @test dot(c, NoTangent()) == NoTangent() + @test norm(Tangent{Foo}(; y=c.y, x=NoTangent())) == c.y + @test norm(NoTangent(), Inf) == 0 + + @test ZeroTangent() * c == ZeroTangent() + @test c * ZeroTangent() == ZeroTangent() + @test dot(ZeroTangent(), c) == ZeroTangent() + @test dot(c, ZeroTangent()) == ZeroTangent() + @test norm(ZeroTangent()) == 0 + @test norm(ZeroTangent(), 0.4) == 0 + + @test true * c === c + @test c * true === c + + t = @thunk 2 + @test t * c == 2 * c + @test c * t == c * 2 end - end - @testset "+ with Primals, with inner constructor" begin - value = StructWithInvariant(10.0) - diff = Tangent{StructWithInvariant}(; x=2.0, x2=6.0) + @testset "-Tangent" begin + t = Tangent{Foo}(; x=1.0, y=-2.0) + @test -t == Tangent{Foo}(; x=-1.0, y=2.0) + @test -1.0 * t == -t + end - @testset "with and without debug mode" begin - @assert ChainRulesCore.debug_mode() == false - @test_throws MethodError (value + diff) - @test_throws MethodError (diff + value) + @testset "scaling" begin + @test ( + 2 * Tangent{Foo}(; y=1.5, x=2.5) == + Tangent{Foo}(; y=3.0, x=5.0) == + Tangent{Foo}(; y=1.5, x=2.5) * 2 + ) + @test ( + 2 * Tangent{Tuple{Float64,Float64}}(2.0, 4.0) == + Tangent{Tuple{Float64,Float64}}(4.0, 8.0) == + Tangent{Tuple{Float64,Float64}}(2.0, 4.0) * 2 + ) + d = Tangent{Dict}(Dict(4 => 3.0)) + two_d = Tangent{Dict}(Dict(4 => 2 * 3.0)) + @test 2 * d == two_d == d * 2 - ChainRulesCore.debug_mode() = true # enable debug mode - @test_throws ChainRulesCore.PrimalAdditionFailedException (value + diff) - @test_throws ChainRulesCore.PrimalAdditionFailedException (diff + value) - ChainRulesCore.debug_mode() = false # disable it again + @test_throws MethodError [1, 2] * Tangent{Foo}(; y=1.5, x=2.5) + @test_throws MethodError [1, 2] * d + @test_throws MethodError Tangent{Foo}(; y=1.5, x=2.5) * @thunk [1 2; 3 4] end - # Now we define constuction for ChainRulesCore.jl's purposes: - # It is going to determine the root quanity of the invarient - function ChainRulesCore.construct(::Type{StructWithInvariant}, nt::NamedTuple) - x = (nt.x + nt.x2 / 2) / 2 - return StructWithInvariant(x) + @testset "iszero" begin + @test iszero(Tangent{Foo}()) + @test iszero(Tangent{Tuple{}}()) + @test iszero(Tangent{Foo}(; x=ZeroTangent())) + @test iszero(Tangent{Foo}(; y=0.0)) + @test iszero(Tangent{Foo}(; x=Tangent{Tuple{}}(), y=0.0)) + + @test !iszero(Tangent{Foo}(; y=3.0)) end - @test value + diff == StructWithInvariant(12.5) - @test diff + value == StructWithInvariant(12.5) - end - @testset "differential arithmetic" begin - c = Tangent{Foo}(; y=1.5, x=2.5) - - @test NoTangent() * c == NoTangent() - @test c * NoTangent() == NoTangent() - @test dot(NoTangent(), c) == NoTangent() - @test dot(c, NoTangent()) == NoTangent() - @test norm(Tangent{Foo}(; y=c.y, x=NoTangent())) == c.y - @test norm(NoTangent(), Inf) == 0 - - @test ZeroTangent() * c == ZeroTangent() - @test c * ZeroTangent() == ZeroTangent() - @test dot(ZeroTangent(), c) == ZeroTangent() - @test dot(c, ZeroTangent()) == ZeroTangent() - @test norm(ZeroTangent()) == 0 - @test norm(ZeroTangent(), 0.4) == 0 - - @test true * c === c - @test c * true === c - - t = @thunk 2 - @test t * c == 2 * c - @test c * t == c * 2 - end + @testset "show" begin + @test repr(Tangent{Foo}(; x=1)) == "Tangent{Foo}(x = 1,)" + # check for exact regex match not occurence( `^...$`) + # and allowing optional whitespace (`\s?`) + @test occursin( + r"^Tangent{Tuple{Int64,\s?Int64}}\(1,\s?2\)$", + repr(Tangent{Tuple{Int64,Int64}}(1, 2)), + ) - @testset "-Tangent" begin - t = Tangent{Foo}(; x=1.0, y=-2.0) - @test -t == Tangent{Foo}(; x=-1.0, y=2.0) - @test -1.0 * t == -t - end + @test repr(Tangent{Foo}()) == "Tangent{Foo}()" + end - @testset "scaling" begin - @test ( - 2 * Tangent{Foo}(; y=1.5, x=2.5) == - Tangent{Foo}(; y=3.0, x=5.0) == - Tangent{Foo}(; y=1.5, x=2.5) * 2 - ) - @test ( - 2 * Tangent{Tuple{Float64,Float64}}(2.0, 4.0) == - Tangent{Tuple{Float64,Float64}}(4.0, 8.0) == - Tangent{Tuple{Float64,Float64}}(2.0, 4.0) * 2 - ) - d = Tangent{Dict}(Dict(4 => 3.0)) - two_d = Tangent{Dict}(Dict(4 => 2 * 3.0)) - @test 2 * d == two_d == d * 2 + @testset "internals" begin + @testset "Can't do backing on primative type" begin + @test_throws Exception ChainRulesCore.backing(1.4) + end - @test_throws MethodError [1, 2] * Tangent{Foo}(; y=1.5, x=2.5) - @test_throws MethodError [1, 2] * d - @test_throws MethodError Tangent{Foo}(; y=1.5, x=2.5) * @thunk [1 2; 3 4] - end + @testset "Internals don't allocate a ton" begin + bk = (; x=1.0, y=2.0) + VERSION >= v"1.5" && + @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 - @testset "iszero" begin - @test iszero(Tangent{Foo}()) - @test iszero(Tangent{Tuple{}}()) - @test iszero(Tangent{Foo}(; x=ZeroTangent())) - @test iszero(Tangent{Foo}(; y=0.0)) - @test iszero(Tangent{Foo}(; x=Tangent{Tuple{}}(), y=0.0)) + # weaker version of the above (which should pass on all versions) + @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 48 + @test (@ballocated ChainRulesCore.elementwise_add($bk, $bk)) <= 48 + end + end - @test !iszero(Tangent{Foo}(; y=3.0)) + @testset "non-same-typed differential arithmetic" begin + nt = (; a=1, b=2.0) + c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1) + @test nt + c == (; a=1, b=2.1) + end + + @testset "printing" begin + t5 = Tuple(rand(3)) + nt3 = (x=t5, y=t5, z=nothing) + tang = ProjectTo(nt3)(nt3) # moderately complicated Tangent + @test contains(sprint(show, tang), "...}(x = Tangent") # gets shortened + @test contains(sprint(show, tang), sprint(show, tang.x)) # inner piece appears whole + end end - @testset "show" begin - @test repr(Tangent{Foo}(; x=1)) == "Tangent{Foo}(x = 1,)" - # check for exact regex match not occurence( `^...$`) - # and allowing optional whitespace (`\s?`) - @test occursin( - r"^Tangent{Tuple{Int64,\s?Int64}}\(1,\s?2\)$", - repr(Tangent{Tuple{Int64,Int64}}(1, 2)), + @testset "MutableTangent" begin + mutable struct MDemo + x::Float64 + end + function ChainRulesCore.frule( + (_, ȯbj, _, ẋ), ::typeof(setfield!), obj::MDemo, field, x ) - - @test repr(Tangent{Foo}()) == "Tangent{Foo}()" - end - - @testset "internals" begin - @testset "Can't do backing on primative type" begin - @test_throws Exception ChainRulesCore.backing(1.4) + y = setfield!(obj, field, x) + ẏ = setproperty!(ȯbj, field, ẋ) + return y, ẏ end - @testset "Internals don't allocate a ton" begin - bk = (; x=1.0, y=2.0) - VERSION >= v"1.5" && - @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 32 - - # weaker version of the above (which should pass on all versions) - @test (@ballocated(ChainRulesCore.construct($Foo, $bk))) <= 48 - @test (@ballocated ChainRulesCore.elementwise_add($bk, $bk)) <= 48 + @testset "usecase" begin + obj = MDemo(99.0) + ∂obj = MutableTangent{MDemo}(; x=1.5) + frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0) + @test ∂obj.x == 10.0 + @test obj.x == 95.0 + + frule((NoTangent(), ∂obj, NoTangent(), 20.0), setfield!, obj, 1, 96.0) + @test ∂obj.x == 20.0 + @test getproperty(∂obj, 1) == 20.0 + @test obj.x == 96.0 end - end - @testset "non-same-typed differential arithmetic" begin - nt = (; a=1, b=2.0) - c = Tangent{typeof(nt)}(; a=NoTangent(), b=0.1) - @test nt + c == (; a=1, b=2.1) - end - - @testset "printing" begin - t5 = Tuple(rand(3)) - nt3 = (x=t5, y=t5, z=nothing) - tang = ProjectTo(nt3)(nt3) # moderately complicated Tangent - @test contains(sprint(show, tang), "...}(x = Tangent") # gets shortened - @test contains(sprint(show, tang), sprint(show, tang.x)) # inner piece appears whole - end -end + @testset "== and hash" begin + @test MutableTangent{MDemo}(; x=1.0f0) == MutableTangent{MDemo}(; x=1.0) + @test MutableTangent{MDemo}(; x=1.0) == MutableTangent{MDemo}(; x=1.0f0) + @test MutableTangent{MDemo}(; x=2.0) != MutableTangent{MDemo}(; x=1.0) + @test MutableTangent{MDemo}(; x=1.0) != MutableTangent{MDemo}(; x=2.0) -@testset "MutableTangent" begin - mutable struct MDemo - x::Float64 - end - function ChainRulesCore.frule( - (_, ȯbj, _, ẋ), ::typeof(setfield!), obj::MDemo, field, x - ) - y = setfield!(obj, field, x) - ẏ = setproperty!(ȯbj, field, ẋ) - return y, ẏ - end + nt = (; x=1.0) + @test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(; x=1.0) - @testset "usecase" begin - obj = MDemo(99.0) - ∂obj = MutableTangent{MDemo}(; x=1.5) - frule((NoTangent(), ∂obj, NoTangent(), 10.0), setfield!, obj, :x, 95.0) - @test ∂obj.x == 10.0 - @test obj.x == 95.0 - - frule((NoTangent(), ∂obj, NoTangent(), 20.0), setfield!, obj, 1, 96.0) - @test ∂obj.x == 20.0 - @test getproperty(∂obj, 1) == 20.0 - @test obj.x == 96.0 - end - - @testset "== and hash" begin - @test MutableTangent{MDemo}(; x=1.0f0) == MutableTangent{MDemo}(; x=1.0) - @test MutableTangent{MDemo}(; x=1.0) == MutableTangent{MDemo}(; x=1.0f0) - @test MutableTangent{MDemo}(; x=2.0) != MutableTangent{MDemo}(; x=1.0) - @test MutableTangent{MDemo}(; x=1.0) != MutableTangent{MDemo}(; x=2.0) - - nt = (; x=1.0) - @test MutableTangent{typeof(nt)}(nt) != MutableTangent{MDemo}(; x=1.0) - - @test hash(MutableTangent{MDemo}(; x=1.0f0)) == hash(MutableTangent{MDemo}(; x=1.0)) - end + @test hash(MutableTangent{MDemo}(; x=1.0f0)) == hash(MutableTangent{MDemo}(; x=1.0)) + end - @testset "Mutation" begin - v = MutableTangent{MFoo}(; x=1.5, y=2.4) - v.x = 1.6 - @test v == MutableTangent{MFoo}(; x=1.6, y=2.4) - v.y = [1.0, 2.0] # change type, because primal can change type - @test v == MutableTangent{MFoo}(; x=1.6, y=[1.0, 2.0]) + @testset "Mutation" begin + v = MutableTangent{MFoo}(; x=1.5, y=2.4) + v.x = 1.6 + @test v == MutableTangent{MFoo}(; x=1.6, y=2.4) + v.y = [1.0, 2.0] # change type, because primal can change type + @test v == MutableTangent{MFoo}(; x=1.6, y=[1.0, 2.0]) + end end -end -@testset "map" begin - @testset "Tangent" begin - ∂foo = Tangent{Foo}(; x=1.5, y=2.4) - @test map(v -> 2 * v, ∂foo) == Tangent{Foo}(; x=3.0, y=4.8) + @testset "map" begin + @testset "Tangent" begin + ∂foo = Tangent{Foo}(; x=1.5, y=2.4) + @test map(v -> 2 * v, ∂foo) == Tangent{Foo}(; x=3.0, y=4.8) - ∂foo = Tangent{Foo}(; x=1.5) - @test map(v -> 2 * v, ∂foo) == Tangent{Foo}(; x=3.0) - end - @testset "MutableTangent" begin - ∂foo = MutableTangent{MFoo}(; x=1.5, y=2.4) - ∂foo2 = map(v -> 2 * v, ∂foo) - @test ∂foo2 == MutableTangent{MFoo}(; x=3.0, y=4.8) - # Check can still be mutated to new typ - ∂foo2.y = [1.0, 2.0] - @test ∂foo2 == MutableTangent{MFoo}(; x=3.0, y=[1.0, 2.0]) + ∂foo = Tangent{Foo}(; x=1.5) + @test map(v -> 2 * v, ∂foo) == Tangent{Foo}(; x=3.0) + end + @testset "MutableTangent" begin + ∂foo = MutableTangent{MFoo}(; x=1.5, y=2.4) + ∂foo2 = map(v -> 2 * v, ∂foo) + @test ∂foo2 == MutableTangent{MFoo}(; x=3.0, y=4.8) + # Check can still be mutated to new typ + ∂foo2.y = [1.0, 2.0] + @test ∂foo2 == MutableTangent{MFoo}(; x=3.0, y=[1.0, 2.0]) + end end end \ No newline at end of file From ade0c3de71d63644c6f781b8eded091bd9297e7c Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 29 Dec 2023 15:22:20 +0800 Subject: [PATCH 24/34] Support types that have no tangent space in zero_tangent --- src/tangent_types/abstract_zero.jl | 10 +++++-- test/tangent_types/abstract_zero.jl | 45 +++++++++++++++++++---------- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index a0f8d1f82..8e31ae492 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -96,10 +96,11 @@ struct NoTangent <: AbstractZero end zero_tangent(primal) This returns an appropriate zero tangent suitable for accumulating tangents of the primal. -For mutable composites types this is a structural []`MutableTangent`](@ref) +For mutable composites types this is a structural [`MutableTangent`](@ref) For `Array`s, it is applied recursively for each element. -For immutable types, this is simply [`ZeroTangent()`](@ref) as accumulation is default out-of-place for contexts where mutation does not apply. -(Where mutation is not to be supported even for mutable types, then [`ZeroTangent()`](@ref) should be used for everything) +For other types, in particular immutable types, we do not make promises beyond that it will be `iszero` +and suitable for accumulating against. +In general though, it is more likely to produce a structural tangent. !!! warning Exprimental `zero_tangent`is an experimental feature, and is part of the mutation support featureset. @@ -110,7 +111,10 @@ function zero_tangent end zero_tangent(x::Number) = zero(x) +zero_tangent(::Type) = NoTangent() + @generated function zero_tangent(primal) + fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero. zfield_exprs = map(fieldnames(primal)) do fname fval = :( if isdefined(primal, $(QuoteNode(fname))) diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 81114511a..960e88d99 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -162,25 +162,38 @@ end @testset "zero_tangent" begin - @test zero_tangent(1) === 0 - @test zero_tangent(1.0) === 0.0 - mutable struct MutDemo - x::Float64 - end - struct Demo - x::Float64 - end - @test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo} - @test iszero(zero_tangent(MutDemo(1.5))) + @testset "basics" begin + @test zero_tangent(1) === 0 + @test zero_tangent(1.0) === 0.0 + mutable struct MutDemo + x::Float64 + end + struct Demo + x::Float64 + end + @test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo} + @test iszero(zero_tangent(MutDemo(1.5))) - @test zero_tangent((; a=1)) isa Tangent{typeof((; a = 1))} - @test zero_tangent(Demo(1.2)) isa Tangent{Demo} - @test zero_tangent(Demo(1.2)).x === 0.0 + @test zero_tangent((; a=1)) isa Tangent{typeof((; a = 1))} + @test zero_tangent(Demo(1.2)) isa Tangent{Demo} + @test zero_tangent(Demo(1.2)).x === 0.0 - @test zero_tangent([1.0, 2.0]) == [0.0, 0.0] - @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] + @test zero_tangent([1.0, 2.0]) == [0.0, 0.0] + @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] + + @test zero_tangent((1.0, 2.0)) == Tangent{Tuple{Float64,Float64}}(0.0, 0.0) + end + + @testset "Weird types" begin + @test iszero(zero_tangent(typeof(Int))) # primative type + @test iszero(zero_tangent(typeof(Base.RefValue))) # struct + @test iszero(zero_tangent(Vector)) # UnionAll + @test iszero(zero_tangent(Union{Int, Float64})) # Union + @test iszero(zero_tangent(:abc)) + @test iszero(zero_tangent("abc")) + @test iszero(zero_tangent(sin)) + end - @test zero_tangent((1.0, 2.0)) == Tangent{Tuple{Float64,Float64}}(0.0, 0.0) @testset "undef elements Vector" begin x = Vector{Vector{Float64}}(undef, 3) x[2] = [1.0, 2.0] From e068cb6138337776a9bd6ab06af033a3a9ae50f1 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 16 Jan 2024 13:59:27 +0800 Subject: [PATCH 25/34] define zero_tangent for Tangent --- src/ChainRulesCore.jl | 2 +- src/tangent_types/abstract_zero.jl | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index bda392497..286f71db2 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -19,9 +19,9 @@ export StructuralTangent, Tangent, MutableTangent, NoTangent, InplaceableThunk, include("debug_mode.jl") include("tangent_types/abstract_tangent.jl") +include("tangent_types/structural_tangent.jl") include("tangent_types/abstract_zero.jl") include("tangent_types/thunks.jl") -include("tangent_types/structural_tangent.jl") include("tangent_types/notimplemented.jl") include("tangent_arithmetic.jl") diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 8e31ae492..86ed92523 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -113,6 +113,9 @@ zero_tangent(x::Number) = zero(x) zero_tangent(::Type) = NoTangent() +zero_tangent(x::Tangent) = ZeroTangent() +# TODO: zero_tangent(x::MutableTangent) + @generated function zero_tangent(primal) fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero. zfield_exprs = map(fieldnames(primal)) do fname From 45de6a737961c4f6fb546b2a5c1c0c34eb7daccb Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 17 Jan 2024 11:43:59 +0800 Subject: [PATCH 26/34] Add structural zero tangent code for higher order --- src/tangent_types/abstract_zero.jl | 11 +++++++++-- test/tangent_types/abstract_zero.jl | 8 ++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 86ed92523..61fc05b6f 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -113,8 +113,15 @@ zero_tangent(x::Number) = zero(x) zero_tangent(::Type) = NoTangent() -zero_tangent(x::Tangent) = ZeroTangent() -# TODO: zero_tangent(x::MutableTangent) +function zero_tangent(x::MutableTangent{P}) where P + zb = backing(zero_tangent(backing(x))) + return MutableTangent{P}(zb) +end + +function zero_tangent(x::Tangent{P}) where P + zb = backing(zero_tangent(backing(x))) + return Tangent{P, typeof(zb)}(zb) +end @generated function zero_tangent(primal) fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero. diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 960e88d99..a4df83ebf 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -182,6 +182,14 @@ end @test zero_tangent([[1.0, 2.0], [3.0]]) == [[0.0, 0.0], [0.0]] @test zero_tangent((1.0, 2.0)) == Tangent{Tuple{Float64,Float64}}(0.0, 0.0) + + # Higher order + # StructuralTangents are valid tangents for themselves (just like Numbers) + # and indeed we prefer that, otherwise higher order structural tangents are kinda + # nightmarishly complex types. + @test zero_tangent(zero_tangent(Demo(1.5))) == zero_tangent(Demo(1.5)) + @test zero_tangent(zero_tangent((1.5, 2.5))) == Tangent{Tuple{Float64, Float64}}(0.0, 0.0) + @test zero_tangent(zero_tangent(MutDemo(1.5))) == zero_tangent(MutDemo(1.5)) end @testset "Weird types" begin From 780ed05616a6c042b19f083ce78391aa1399c081 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 17 Jan 2024 12:05:02 +0800 Subject: [PATCH 27/34] Formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/abstract_zero.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 61fc05b6f..d4f17d852 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -113,7 +113,7 @@ zero_tangent(x::Number) = zero(x) zero_tangent(::Type) = NoTangent() -function zero_tangent(x::MutableTangent{P}) where P +function zero_tangent(x::MutableTangent{P}) where {P} zb = backing(zero_tangent(backing(x))) return MutableTangent{P}(zb) end From 2795872286a5516853d7800162c30900dffababc Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 17 Jan 2024 13:32:57 +0800 Subject: [PATCH 28/34] overload show for mutable tangent --- src/tangent_types/structural_tangent.jl | 5 ++++- test/tangent_types/structural_tangent.jl | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index c7ae1b1b5..7730a6215 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -334,7 +334,10 @@ Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P,Q} = false Base.hash(a::Tangent, h::UInt) = Base.hash(backing(canonicalize(a)), h) -function Base.show(io::IO, tangent::Tangent{P}) where {P} +function Base.show(io::IO, tangent::StructuralTangent{P}) where {P} + if tangent isa MutableTangent + print(io, "Mutable") + end print(io, "Tangent{") str = sprint(show, P, context = io) i = findfirst('{', str) diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index 16d702c14..b902e89e4 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -398,6 +398,11 @@ end ) @test repr(Tangent{Foo}()) == "Tangent{Foo}()" + + @test ==( + repr(MutableTangent{MFoo}((;x=1.5, y=[1.0, 2.0]))), + "MutableTangent{MFoo}(x = 1.5, y = [1.0, 2.0])" + ) end @testset "internals" begin From 7e9e778d0c700fcece2a84e3d880d8d8d1f57af3 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 19 Jan 2024 12:09:12 +0800 Subject: [PATCH 29/34] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/abstract_zero.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index d4f17d852..15aa00d42 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -118,9 +118,9 @@ function zero_tangent(x::MutableTangent{P}) where {P} return MutableTangent{P}(zb) end -function zero_tangent(x::Tangent{P}) where P +function zero_tangent(x::Tangent{P}) where {P} zb = backing(zero_tangent(backing(x))) - return Tangent{P, typeof(zb)}(zb) + return Tangent{P,typeof(zb)}(zb) end @generated function zero_tangent(primal) From e912e46ba64c417c644bf75b18d2379680c5ece3 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 23 Jan 2024 17:08:56 +0800 Subject: [PATCH 30/34] move show code to `Common` area --- src/tangent_types/structural_tangent.jl | 53 +++++++++++++------------ 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 7730a6215..a469d9f1b 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -140,6 +140,33 @@ function Base.map(f, tangent::StructuralTangent{P}) where {P} end end + +function Base.show(io::IO, tangent::StructuralTangent{P}) where {P} + if tangent isa MutableTangent + print(io, "Mutable") + end + print(io, "Tangent{") + str = sprint(show, P, context = io) + i = findfirst('{', str) + if isnothing(i) + print(io, str) + else # for Tangent{T{A,B,C}}(stuff), print {A,B,C} in grey, and trim this part if longer than a line: + print(io, str[1:prevind(str, i)]) + if length(str) < 80 + printstyled(io, str[i:end], color=:light_black) + else + printstyled(io, str[i:prevind(str, 80)], "...", color=:light_black) + end + end + print(io, "}") + if isempty(backing(tangent)) + print(io, "()") # so it doesn't show `NamedTuple()` + else + # allow Tuple or NamedTuple `show` to do the rendering of brackets etc + show(io, backing(tangent)) + end +end + """ backing(x) @@ -334,32 +361,6 @@ Base.:(==)(a::Tangent{P}, b::Tangent{Q}) where {P,Q} = false Base.hash(a::Tangent, h::UInt) = Base.hash(backing(canonicalize(a)), h) -function Base.show(io::IO, tangent::StructuralTangent{P}) where {P} - if tangent isa MutableTangent - print(io, "Mutable") - end - print(io, "Tangent{") - str = sprint(show, P, context = io) - i = findfirst('{', str) - if isnothing(i) - print(io, str) - else # for Tangent{T{A,B,C}}(stuff), print {A,B,C} in grey, and trim this part if longer than a line: - print(io, str[1:prevind(str, i)]) - if length(str) < 80 - printstyled(io, str[i:end], color=:light_black) - else - printstyled(io, str[i:prevind(str, 80)], "...", color=:light_black) - end - end - print(io, "}") - if isempty(backing(tangent)) - print(io, "()") # so it doesn't show `NamedTuple()` - else - # allow Tuple or NamedTuple `show` to do the rendering of brackets etc - show(io, backing(tangent)) - end -end - Base.iszero(::Tangent{<:,NamedTuple{}}) = true Base.iszero(::Tangent{<:,Tuple{}}) = true From e478e7fd201cb568d648d9658ed3656ad70fa672 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 23 Jan 2024 17:22:36 +0800 Subject: [PATCH 31/34] docs more consistent --- .../superpowers/mutation_support.md | 23 +++++++++++++------ src/tangent_types/abstract_zero.jl | 3 ++- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/docs/src/rule_author/superpowers/mutation_support.md b/docs/src/rule_author/superpowers/mutation_support.md index 55629166a..a4dec8ab8 100644 --- a/docs/src/rule_author/superpowers/mutation_support.md +++ b/docs/src/rule_author/superpowers/mutation_support.md @@ -1,6 +1,6 @@ # Mutation Support -ChainRulesCore.jl offers experimental support for mutation, targetting use in forward mode AD. +ChainRulesCore.jl offers experimental support for mutation, targeting use in forward mode AD. (Mutation support in reverse mode AD is more complicated and will likely require more changes to the interface) !!! warning "Experimental" @@ -17,18 +17,23 @@ It is required to be a structural tangent, having one tangent for each field of Technically, not all `mutable struct`s need to use `MutableTangent` to represent their tangents. Just like not all `struct`s need to use `Tangent`s. Common examples away from this are natural tangent types like for arrays. -However, if one is setting up to use a custom tangent type for this it is surficiently off the beated path that we can not provide much guidance. +However, if one is setting up to use a custom tangent type for this it is sufficiently off the beaten path that we can not provide much guidance. ## `zero_tangent` The [`zero_tangent`](@ref) function functions to give you a zero (i.e. additive identity) for any primal value. The [`ZeroTangent`](@ref) type also does this. -The difference is that [`zero_tangent`](@ref) is (where possible) a full structural tangent mirroring the structure of the primal. +The difference is that [`zero_tangent`](@ref) is in general full structural tangent mirroring the structure of the primal. +To be technical the promise of [`zero_tangent`](@ref) is that it will be a value that supports mutation. +However, in practice[^1] this is achieved through in a structural tangent For mutation support this is important, since it means that there is mutable memory available in the tangent to be mutated when the primal changes. To support this you thus need to make sure your zeros are created in various places with [`zero_tangent`](@ref) rather than []`ZeroTangent`](@ref). -It is also useful for reasons of type stability, since it is always a structural tangent. -For this reason AD system implementors might chose to use this to create the tangent for all literal values they encounter, mutable or not. + + +It is also useful for reasons of type stability, since it forces a consistent type (generally a structural tangent) for any given primal type. +For this reason AD system implementors might chose to use this to create the tangent for all literal values they encounter, mutable or not, +and to process the output of `frule`s to convert [`ZeroTangent`](@ref) into corresponding [`zero_tangent`](@ref)s. ## Writing a frule for a mutating function It is relatively straight forward to write a frule for a mutating function. @@ -41,7 +46,7 @@ There are a few key points to follow: ### Example For example, consider the primal function with: 1. takes two `Ref`s -2. doubles the first one inplace +2. doubles the first one in place 3. overwrites the second one's value with the literal 5.0 4. returns the first one @@ -70,4 +75,8 @@ function ChainRulesCore.frule((ȧ, ḃ), ::typeof(foo!), a::Base.RefValue, b::B end ``` -Then assuming the AD system does its part to makes sure you are indeed given mutable values to mutate (i.e. those `@assert`ions are true) then all is well and this rule will make mutation correct. \ No newline at end of file +Then assuming the AD system does its part to makes sure you are indeed given mutable values to mutate (i.e. those `@assert`ions are true) then all is well and this rule will make mutation correct. + +[^1]: + Further, it is hard to achieve this promise of allowing mutation to be supported without returning a structural tangent. + Except in the special case of where the struct is not mutable and has no nested fields that are mutable. \ No newline at end of file diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 15aa00d42..f921db29d 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -100,7 +100,8 @@ For mutable composites types this is a structural [`MutableTangent`](@ref) For `Array`s, it is applied recursively for each element. For other types, in particular immutable types, we do not make promises beyond that it will be `iszero` and suitable for accumulating against. -In general though, it is more likely to produce a structural tangent. +For types without a tangent space (e.g. singleton structs) this returns `NoTangent()`. +In general, it is more likely to produce a structural tangent. !!! warning Exprimental `zero_tangent`is an experimental feature, and is part of the mutation support featureset. From 7d95866e3c10420b190a8cb2248f1d67fa0615e1 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 23 Jan 2024 17:31:58 +0800 Subject: [PATCH 32/34] Update src/tangent_types/structural_tangent.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/structural_tangent.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index a469d9f1b..04d93800f 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -94,7 +94,6 @@ struct MutableTangent{P,F} <: StructuralTangent{P} end end - #################################################################### # StructuralTangent Common From fe63c3380b591360b3b14e95582e126cc8636417 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 23 Jan 2024 17:32:32 +0800 Subject: [PATCH 33/34] Update test/tangent_types/structural_tangent.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/tangent_types/structural_tangent.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index b902e89e4..c177b05f4 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -400,8 +400,8 @@ end @test repr(Tangent{Foo}()) == "Tangent{Foo}()" @test ==( - repr(MutableTangent{MFoo}((;x=1.5, y=[1.0, 2.0]))), - "MutableTangent{MFoo}(x = 1.5, y = [1.0, 2.0])" + repr(MutableTangent{MFoo}((; x=1.5, y=[1.0, 2.0]))), + "MutableTangent{MFoo}(x = 1.5, y = [1.0, 2.0])", ) end From 5fbbe5ba3ccb52b0aaf7411f1099d3a7c5460c1d Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 24 Jan 2024 10:17:54 +0800 Subject: [PATCH 34/34] Handle circular references with-in mutable structs format self-refrential (squash me into prev) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/tangent_types/abstract_zero.jl | 80 ++++++++++++++++--------- src/tangent_types/structural_tangent.jl | 19 ++++++ test/tangent_types/abstract_zero.jl | 39 +++++++++++- 3 files changed, 110 insertions(+), 28 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index f921db29d..e51a99f09 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -93,7 +93,7 @@ arguments. struct NoTangent <: AbstractZero end """ - zero_tangent(primal) + zero_tangent(primal, _cache=nothing) This returns an appropriate zero tangent suitable for accumulating tangents of the primal. For mutable composites types this is a structural [`MutableTangent`](@ref) @@ -107,55 +107,77 @@ In general, it is more likely to produce a structural tangent. `zero_tangent`is an experimental feature, and is part of the mutation support featureset. While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore. Exactly how it should be used (e.g. is it forward-mode only?) + +The `_cache=nothing` is an internal implementation detail that the user should never need to set. +(It is used to hold references to tangents for that might appear in self-referential structures) """ function zero_tangent end -zero_tangent(x::Number) = zero(x) +zero_tangent(x::Number, _cache=nothing) = zero(x) -zero_tangent(::Type) = NoTangent() +zero_tangent(::Type, _cache=nothing) = NoTangent() -function zero_tangent(x::MutableTangent{P}) where {P} - zb = backing(zero_tangent(backing(x))) +function zero_tangent(x::MutableTangent{P}, _cache=nothing) where {P} + zb = backing(zero_tangent(backing(x), _cache)) return MutableTangent{P}(zb) end -function zero_tangent(x::Tangent{P}) where {P} - zb = backing(zero_tangent(backing(x))) +function zero_tangent(x::Tangent{P}, _cache=nothing) where {P} + zb = backing(zero_tangent(backing(x), _cache)) return Tangent{P,typeof(zb)}(zb) end -@generated function zero_tangent(primal) +@generated function zero_tangent(primal, _cache=nothing) fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero. zfield_exprs = map(fieldnames(primal)) do fname - fval = :( + :( if isdefined(primal, $(QuoteNode(fname))) - zero_tangent(getfield(primal, $(QuoteNode(fname)))) + zero_tangent(getfield(primal, $(QuoteNode(fname))), _cache) else # This is going to be potentially bad, but that's what they get for not giving us a primal # This will never me mutated inplace, rather it will alway be replaced with an actual value first ZeroTangent() end ) - Expr(:kw, fname, fval) end return if has_mutable_tangent(primal) - any_mask = map(fieldnames(primal), fieldtypes(primal)) do fname, ftype - # If it is is unassigned, or if it doesn't have a concrete type, let it take any value for its tangent - fdef = :(!isdefined(primal, $(QuoteNode(fname))) || !isconcretetype($ftype)) - Expr(:kw, fname, fdef) + # This is a little complex because we need to support-self referential types + # So we need to: + # 1. create the tangent, + # 2. put it in the cache + # 3. Do all the calls to create the zeros for the fields giving them that cache) + # 4. put those zeros into the object + tangent_types = map(guess_zero_tangent_type, fieldtypes(primal)) + is_defined_mask = Expr(:tuple, map(fieldnames(primal)) do fname + :(isdefined(primal, $(QuoteNode(fname)))) + end...) + + quote + isnothing(_cache) && (_cache = IdDict()) + found_tangent = get(_cache, primal, nothing) + !isnothing(found_tangent) && return found_tangent + + # Now we need to put into the cache a placeholder tangent so we can construct our fields using that cache + # then put those fields into the placeholder + tangent = $_MutableTangent(Val{$primal}(), $is_defined_mask, $tangent_types) + _cache[primal] = tangent + $( + map(fieldnames(primal), zfield_exprs) do fname, fval_expr + :(setproperty!(tangent, $(QuoteNode(fname)), $fval_expr)) + end... + ) + return tangent end - :($MutableTangent{$primal}( - $(Expr(:tuple, Expr(:parameters, any_mask...))), - $(Expr(:tuple, Expr(:parameters, zfield_exprs...))), - )) else - :($Tangent{$primal}($(Expr(:parameters, zfield_exprs...)))) + :($Tangent{$primal}($(Expr(:parameters, Expr.(:kw, fieldnames(primal), zfield_exprs)...)))) end end -zero_tangent(primal::Tuple) = Tangent{typeof(primal)}(map(zero_tangent, primal)...) +function zero_tangent(primal::Tuple, _cache=nothing) + return Tangent{typeof(primal)}(map(x -> zero_tangent(x, _cache), primal)...) +end -function zero_tangent(x::Array{P,N}) where {P,N} +function zero_tangent(x::Array{P,N}, _cache=nothing) where {P,N} if (isbitstype(P) || all(i -> isassigned(x, i), eachindex(x))) return map(zero_tangent, x) end @@ -165,16 +187,20 @@ function zero_tangent(x::Array{P,N}) where {P,N} y = Array{guess_zero_tangent_type(P),N}(undef, size(x)...) @inbounds for n in eachindex(y) if isassigned(x, n) - y[n] = zero_tangent(x[n]) + y[n] = zero_tangent(x[n], _cache) end end return y end -# Sad heauristic methods we need because of unassigned values -guess_zero_tangent_type(::Type{T}) where {T<:Number} = T -guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T))) +# Sad heauristic methods +#guess_zero_tangent_type(::Type{T}) where {T<:Number} = T +#guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T))) function guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} return Array{guess_zero_tangent_type(T),N} end -guess_zero_tangent_type(T::Type) = Any \ No newline at end of file + +# The following will fall back to `Any` if it is hard to infer +function guess_zero_tangent_type(::Type{T}) where {T} + return Core.Compiler.return_type(zero_tangent, Tuple{T}) +end \ No newline at end of file diff --git a/src/tangent_types/structural_tangent.jl b/src/tangent_types/structural_tangent.jl index 04d93800f..71228485a 100644 --- a/src/tangent_types/structural_tangent.jl +++ b/src/tangent_types/structural_tangent.jl @@ -54,6 +54,7 @@ struct Tangent{P,T} <: StructuralTangent{P} end end +function _MutableTangent end """ MutableTangent{P}(fields) <: StructuralTangent{P} <: AbstractTangent @@ -73,6 +74,23 @@ It itself is also mutable. struct MutableTangent{P,F} <: StructuralTangent{P} backing::F + # Uninitialized constructor + global function _MutableTangent(::Val{P}, is_defined_mask, tangent_types) where {P} + backing_vals = map(is_defined_mask, tangent_types) do is_def, tangent_type + ref = if !is_def + Ref{Union{ZeroTangent, tangent_type}} # allow a Zero which will be used for uninitialized values + else + Ref{tangent_type} + end + return ref() # undefined, but it will be filled later + end + backing = NamedTuple{fieldnames(P)}(backing_vals) + return new{P, typeof(backing)}(backing) + end + + # TODO: are the following two correct? + # Are they useful? + # The place they are used is just `map`, maybe we should instead just copy types the thing being mapped? function MutableTangent{P}( any_mask::NamedTuple{names, <:NTuple{<:Any, Bool}}, fvals::NamedTuple{names} ) where {names, P} @@ -88,6 +106,7 @@ struct MutableTangent{P,F} <: StructuralTangent{P} return new{P, typeof(backing)}(backing) end + function MutableTangent{P}(fvals) where P any_mask = NamedTuple{fieldnames(P)}((!isconcretetype).(fieldtypes(P))) return MutableTangent{P}(any_mask, fvals) diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index a4df83ebf..741d497b0 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -174,7 +174,7 @@ end @test zero_tangent(MutDemo(1.5)) isa MutableTangent{MutDemo} @test iszero(zero_tangent(MutDemo(1.5))) - @test zero_tangent((; a=1)) isa Tangent{typeof((; a = 1))} + @test zero_tangent((; a=1.3)) isa Tangent{typeof((; a = 1.3))} @test zero_tangent(Demo(1.2)) isa Tangent{Demo} @test zero_tangent(Demo(1.2)).x === 0.0 @@ -275,4 +275,41 @@ end @test d.z == [2.0, 3.0] @test d.z isa SubArray end + + @testset "cyclic references" begin + mutable struct Link + data::Float64 + next::Link + Link(data) = new(data) + end + + lk = Link(1.5) + lk.next = lk + + d = zero_tangent(lk) + @test d.data == 0.0 + @test d.next === d + + # The following two cases are broken + # We hope they are not too significant, because in general if you AD step by step they should work + # (as should the one above so maybe we should get rid of this extra complex logic) + # It's only a problem if you first do the multistep build then `zero_tangent` rather than `zero_tangent` at the constructor. + + # Idea: check if `!isbitstype` only if so do we need to worry about caching etc + struct CarryingArray + x::Vector + end + ca = CarryingArray(Any[1.5]) + push!(ca.x, ca) + @test_broken d_ca = zero_tangent(ca) + @test_broken d_ca[1] == 0.0 + @test_broken d_ca[2] === _ca + + # Idea: check if typeof(xs) <: eltype(xs), if so need to cache it before computing + xs = Any[1.5] + push!(xs, xs) + @test_broken d_xs = zero_tangent(xs) + @test_broken d_xs[1] == 0.0 + @test_broken d_xs[2] == d_xs + end end