-
Notifications
You must be signed in to change notification settings - Fork 63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: handle zero_tangent in presence of cyclic structures (v1) #654
Changes from all commits
7f15d11
c9b4938
a51e51e
93af90b
418b5ce
724ba1b
c9a65df
06b51f5
ed3aa1d
f45fbc7
b2bdb26
0438217
4852c91
0f82019
5574691
e9cc221
baea9d3
a27f1b6
4cfce0b
ad9a5af
8b3d525
c09ff91
59fc470
ade0c3d
e068cb6
45de6a7
780ed05
2795872
7e9e778
e912e46
e478e7f
7d95866
fe63c33
5fbbe5b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Mutation Support | ||
|
||
ChainRulesCore.jl offers experimental support for mutation, targeting use in forward mode AD. | ||
(Mutation support in reverse mode AD is more complicated and will likely require more changes to the interface) | ||
|
||
!!! warning "Experimental" | ||
This page documents an experimental feature. | ||
Expect breaking changes in minor versions while this remains. | ||
It is not suitable for general use unless you are prepared to modify how you are using it each minor release. | ||
It is thus suggested that if you are using it to use _tilde_ bounds on supported minor versions. | ||
|
||
|
||
## `MutableTangent` | ||
The [`MutableTangent`](@ref) type is designed to be a partner to the [`Tangent`](@ref) type, with specific support for being mutated in place. | ||
It is required to be a structural tangent, having one tangent for each field of the primal object. | ||
|
||
Technically, not all `mutable struct`s need to use `MutableTangent` to represent their tangents. | ||
Just like not all `struct`s need to use `Tangent`s. | ||
Common examples away from this are natural tangent types like for arrays. | ||
However, if one is setting up to use a custom tangent type for this it is sufficiently off the beaten path that we can not provide much guidance. | ||
|
||
## `zero_tangent` | ||
|
||
The [`zero_tangent`](@ref) function functions to give you a zero (i.e. additive identity) for any primal value. | ||
The [`ZeroTangent`](@ref) type also does this. | ||
The difference is that [`zero_tangent`](@ref) is in general full structural tangent mirroring the structure of the primal. | ||
To be technical the promise of [`zero_tangent`](@ref) is that it will be a value that supports mutation. | ||
However, in practice[^1] this is achieved through in a structural tangent | ||
For mutation support this is important, since it means that there is mutable memory available in the tangent to be mutated when the primal changes. | ||
To support this you thus need to make sure your zeros are created in various places with [`zero_tangent`](@ref) rather than []`ZeroTangent`](@ref). | ||
|
||
|
||
|
||
It is also useful for reasons of type stability, since it forces a consistent type (generally a structural tangent) for any given primal type. | ||
For this reason AD system implementors might chose to use this to create the tangent for all literal values they encounter, mutable or not, | ||
and to process the output of `frule`s to convert [`ZeroTangent`](@ref) into corresponding [`zero_tangent`](@ref)s. | ||
|
||
## Writing a frule for a mutating function | ||
It is relatively straight forward to write a frule for a mutating function. | ||
There are a few key points to follow: | ||
- There must be a mutable tangent input for every mutated primal input | ||
- When the primal value is changed, the corresponding change must be made to its tangent partner | ||
- When a value is returned, return its partnered tangent. | ||
|
||
|
||
### Example | ||
For example, consider the primal function with: | ||
1. takes two `Ref`s | ||
2. doubles the first one in place | ||
3. overwrites the second one's value with the literal 5.0 | ||
4. returns the first one | ||
|
||
|
||
```julia | ||
function foo!(a::Base.RefValue, b::Base.RefValue) | ||
a[] *= 2 | ||
b[] = 5.0 | ||
return a | ||
end | ||
``` | ||
|
||
The frule for this would be: | ||
```julia | ||
function ChainRulesCore.frule((ȧ, ḃ), ::typeof(foo!), a::Base.RefValue, b::Base.RefValue) | ||
@assert ȧ isa MutableTangent{typeof(a)} | ||
@assert ḃ isa MutableTangent{typeof(b)} | ||
|
||
a[] *= 2 | ||
ȧ.x *= 2 # `.x` is the field that lives behind RefValues | ||
|
||
b[]=5.0 | ||
ḃ.x = zero_tangent(5.0) # or since we know that the zero for a Float64 is zero could write `ḃ.x = 0.0` | ||
|
||
return a, ȧ | ||
end | ||
``` | ||
|
||
Then assuming the AD system does its part to makes sure you are indeed given mutable values to mutate (i.e. those `@assert`ions are true) then all is well and this rule will make mutation correct. | ||
|
||
[^1]: | ||
Further, it is hard to achieve this promise of allowing mutation to be supported without returning a structural tangent. | ||
Except in the special case of where the struct is not mutable and has no nested fields that are mutable. |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -91,3 +91,116 @@ arguments. | |||||||||||||||||||||
``` | ||||||||||||||||||||||
""" | ||||||||||||||||||||||
struct NoTangent <: AbstractZero end | ||||||||||||||||||||||
|
||||||||||||||||||||||
""" | ||||||||||||||||||||||
zero_tangent(primal, _cache=nothing) | ||||||||||||||||||||||
|
||||||||||||||||||||||
This returns an appropriate zero tangent suitable for accumulating tangents of the primal. | ||||||||||||||||||||||
For mutable composites types this is a structural [`MutableTangent`](@ref) | ||||||||||||||||||||||
For `Array`s, it is applied recursively for each element. | ||||||||||||||||||||||
For other types, in particular immutable types, we do not make promises beyond that it will be `iszero` | ||||||||||||||||||||||
and suitable for accumulating against. | ||||||||||||||||||||||
For types without a tangent space (e.g. singleton structs) this returns `NoTangent()`. | ||||||||||||||||||||||
In general, it is more likely to produce a structural tangent. | ||||||||||||||||||||||
|
||||||||||||||||||||||
!!! warning Exprimental | ||||||||||||||||||||||
`zero_tangent`is an experimental feature, and is part of the mutation support featureset. | ||||||||||||||||||||||
While this notice remains it may have changes in behavour, and interface in any _minor_ version of ChainRulesCore. | ||||||||||||||||||||||
Exactly how it should be used (e.g. is it forward-mode only?) | ||||||||||||||||||||||
|
||||||||||||||||||||||
The `_cache=nothing` is an internal implementation detail that the user should never need to set. | ||||||||||||||||||||||
(It is used to hold references to tangents for that might appear in self-referential structures) | ||||||||||||||||||||||
""" | ||||||||||||||||||||||
function zero_tangent end | ||||||||||||||||||||||
|
||||||||||||||||||||||
zero_tangent(x::Number, _cache=nothing) = zero(x) | ||||||||||||||||||||||
|
||||||||||||||||||||||
zero_tangent(::Type, _cache=nothing) = NoTangent() | ||||||||||||||||||||||
|
||||||||||||||||||||||
function zero_tangent(x::MutableTangent{P}, _cache=nothing) where {P} | ||||||||||||||||||||||
zb = backing(zero_tangent(backing(x), _cache)) | ||||||||||||||||||||||
return MutableTangent{P}(zb) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
function zero_tangent(x::Tangent{P}, _cache=nothing) where {P} | ||||||||||||||||||||||
zb = backing(zero_tangent(backing(x), _cache)) | ||||||||||||||||||||||
return Tangent{P,typeof(zb)}(zb) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
@generated function zero_tangent(primal, _cache=nothing) | ||||||||||||||||||||||
fieldcount(primal) == 0 && return NoTangent() # no tangent space at all, no need for structural zero. | ||||||||||||||||||||||
zfield_exprs = map(fieldnames(primal)) do fname | ||||||||||||||||||||||
:( | ||||||||||||||||||||||
if isdefined(primal, $(QuoteNode(fname))) | ||||||||||||||||||||||
zero_tangent(getfield(primal, $(QuoteNode(fname))), _cache) | ||||||||||||||||||||||
else | ||||||||||||||||||||||
# This is going to be potentially bad, but that's what they get for not giving us a primal | ||||||||||||||||||||||
# This will never me mutated inplace, rather it will alway be replaced with an actual value first | ||||||||||||||||||||||
ZeroTangent() | ||||||||||||||||||||||
end | ||||||||||||||||||||||
) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
return if has_mutable_tangent(primal) | ||||||||||||||||||||||
# This is a little complex because we need to support-self referential types | ||||||||||||||||||||||
# So we need to: | ||||||||||||||||||||||
# 1. create the tangent, | ||||||||||||||||||||||
# 2. put it in the cache | ||||||||||||||||||||||
# 3. Do all the calls to create the zeros for the fields giving them that cache) | ||||||||||||||||||||||
# 4. put those zeros into the object | ||||||||||||||||||||||
tangent_types = map(guess_zero_tangent_type, fieldtypes(primal)) | ||||||||||||||||||||||
is_defined_mask = Expr(:tuple, map(fieldnames(primal)) do fname | ||||||||||||||||||||||
:(isdefined(primal, $(QuoteNode(fname)))) | ||||||||||||||||||||||
end...) | ||||||||||||||||||||||
|
||||||||||||||||||||||
quote | ||||||||||||||||||||||
isnothing(_cache) && (_cache = IdDict()) | ||||||||||||||||||||||
found_tangent = get(_cache, primal, nothing) | ||||||||||||||||||||||
!isnothing(found_tangent) && return found_tangent | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Now we need to put into the cache a placeholder tangent so we can construct our fields using that cache | ||||||||||||||||||||||
# then put those fields into the placeholder | ||||||||||||||||||||||
tangent = $_MutableTangent(Val{$primal}(), $is_defined_mask, $tangent_types) | ||||||||||||||||||||||
_cache[primal] = tangent | ||||||||||||||||||||||
$( | ||||||||||||||||||||||
map(fieldnames(primal), zfield_exprs) do fname, fval_expr | ||||||||||||||||||||||
:(setproperty!(tangent, $(QuoteNode(fname)), $fval_expr)) | ||||||||||||||||||||||
end... | ||||||||||||||||||||||
) | ||||||||||||||||||||||
return tangent | ||||||||||||||||||||||
Comment on lines
+164
to
+169
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
end | ||||||||||||||||||||||
else | ||||||||||||||||||||||
:($Tangent{$primal}($(Expr(:parameters, Expr.(:kw, fieldnames(primal), zfield_exprs)...)))) | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
||||||||||||||||||||||
end | ||||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
function zero_tangent(primal::Tuple, _cache=nothing) | ||||||||||||||||||||||
return Tangent{typeof(primal)}(map(x -> zero_tangent(x, _cache), primal)...) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
function zero_tangent(x::Array{P,N}, _cache=nothing) where {P,N} | ||||||||||||||||||||||
if (isbitstype(P) || all(i -> isassigned(x, i), eachindex(x))) | ||||||||||||||||||||||
return map(zero_tangent, x) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Now we need to handle nonfully assigned arrays | ||||||||||||||||||||||
# see discussion at https://github.com/JuliaDiff/ChainRulesCore.jl/pull/626#discussion_r1345235265 | ||||||||||||||||||||||
y = Array{guess_zero_tangent_type(P),N}(undef, size(x)...) | ||||||||||||||||||||||
@inbounds for n in eachindex(y) | ||||||||||||||||||||||
if isassigned(x, n) | ||||||||||||||||||||||
y[n] = zero_tangent(x[n], _cache) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
end | ||||||||||||||||||||||
return y | ||||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Sad heauristic methods | ||||||||||||||||||||||
#guess_zero_tangent_type(::Type{T}) where {T<:Number} = T | ||||||||||||||||||||||
#guess_zero_tangent_type(::Type{T}) where {T<:Integer} = typeof(float(zero(T))) | ||||||||||||||||||||||
function guess_zero_tangent_type(::Type{<:Array{T,N}}) where {T,N} | ||||||||||||||||||||||
return Array{guess_zero_tangent_type(T),N} | ||||||||||||||||||||||
end | ||||||||||||||||||||||
|
||||||||||||||||||||||
# The following will fall back to `Any` if it is hard to infer | ||||||||||||||||||||||
function guess_zero_tangent_type(::Type{T}) where {T} | ||||||||||||||||||||||
return Core.Compiler.return_type(zero_tangent, Tuple{T}) | ||||||||||||||||||||||
end | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶