Skip to content

Commit

Permalink
ChainRulesCore v1.0 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Sep 1, 2021
1 parent a947977 commit 9818994
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
ArgCheck = "1, 2"
ArraysOfArrays = "0.5"
ChainRulesCore = "0.9.44, 0.10"
ChainRulesCore = "0.9.44, 0.10, 1"
Distributions = "0.23, 0.24, 0.25"
ElasticArrays = "1.0"
FillArrays = "0.7, 0.8, 0.9, 0.10, 0.11, 0.12"
Expand Down
2 changes: 1 addition & 1 deletion src/ValueShapes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import TypedTables
# Long-term, ChainRulesCore should be sufficient:
import ZygoteRules

using ChainRulesCore: AbstractTangent, Tangent, NoTangent, ZeroTangent
using ChainRulesCore: AbstractTangent, Tangent, NoTangent, ZeroTangent, AbstractThunk, unthunk

include("value_shape.jl")
include("value_accessor.jl")
Expand Down
11 changes: 6 additions & 5 deletions src/named_tuple_shape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ end
# Zygote will currently ignore this, see Zygote.jl issue #811:
function ChainRulesCore.rrule(::typeof(Base.getindex), x::ShapedAsNT)
result = x[]
shapedasnt_getindex_pullback(ΔΩ::NamedTuple) = (NoTangent(), _shaped_nt_ΔΩ(ΔΩ, result, x))
shapedasnt_getindex_pullback(ΔΩ) = (NoTangent(), _shaped_nt_ΔΩ(unthunk(ΔΩ), result, x))
return result, shapedasnt_getindex_pullback
end
#
Expand All @@ -361,17 +361,18 @@ end

function ChainRulesCore.rrule(::Type{ShapedAsNT}, A::AbstractVector{<:Real}, vs::NamedTupleShape{names}) where names
result = ShapedAsNT(A, vs)
function shapedasnt_pullback(ΔΩ::Union{ShapedAsNT{<:NamedTuple{names}},NamedTuple{names}})
function shapedasnt_pullback_impl(ΔΩ::Union{ShapedAsNT{<:NamedTuple{names}},NamedTuple{names}})
(NoTangent(), unshaped(ΔΩ, gradient_shape(vs)), nothing)
end
function shapedasnt_pullback(ΔΩ_c::Tangent{Any,<:NamedTuple{names}})
function shapedasnt_pullback_impl(ΔΩ_c::Tangent{Any,<:NamedTuple{names}})
ΔΩ = NamedTuple{names}((ΔΩ_c...,))
shapedasnt_pullback(ΔΩ)
shapedasnt_pullback_impl(ΔΩ)
end
function shapedasnt_pullback(ΔΩ_c::Tangent{Any,<:NamedTuple{(:__internal_data, :__internal_valshape)}})
function shapedasnt_pullback_impl(ΔΩ_c::Tangent{Any,<:NamedTuple{(:__internal_data, :__internal_valshape)}})
@assert ΔΩ_c.__internal_valshape == NoTangent() || ΔΩ_c.__internal_valshape == ZeroTangent()
(NoTangent(), ΔΩ_c.__internal_data, nothing)
end
shapedasnt_pullback(ΔΩ) = shapedasnt_pullback_impl(unthunk(ΔΩ))
return result, shapedasnt_pullback
end

Expand Down

0 comments on commit 9818994

Please sign in to comment.