Allow ChainRulesCore 1.x and unthunk cotangents in the ShapedAsNT rrules.

- Project.toml: add "1" to the ChainRulesCore [compat] entry.
- src/ValueShapes.jl: additionally import AbstractThunk and unthunk.
- src/named_tuple_shape.jl: both pullbacks now accept any cotangent and call
  unthunk on it before use; the typed methods of the constructor pullback are
  renamed to shapedasnt_pullback_impl behind an untyped wrapper.

NOTE(review): line breaks, indentation and blank context lines were
reconstructed from a whitespace-collapsed copy of this patch (4-space indent
assumed); verify the context lines against the target files before applying.

diff --git a/Project.toml b/Project.toml
index db4a4ffa..1f3c1df3 100644
--- a/Project.toml
+++ b/Project.toml
@@ -20,7 +20,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
 [compat]
 ArgCheck = "1, 2"
 ArraysOfArrays = "0.5"
-ChainRulesCore = "0.9.44, 0.10"
+ChainRulesCore = "0.9.44, 0.10, 1"
 Distributions = "0.23, 0.24, 0.25"
 ElasticArrays = "1.0"
 FillArrays = "0.7, 0.8, 0.9, 0.10, 0.11, 0.12"
diff --git a/src/ValueShapes.jl b/src/ValueShapes.jl
index 7b5e191f..3c469ec5 100644
--- a/src/ValueShapes.jl
+++ b/src/ValueShapes.jl
@@ -27,7 +27,7 @@ import TypedTables
 # Long-term, ChainRulesCore should be sufficient:
 import ZygoteRules
 
-using ChainRulesCore: AbstractTangent, Tangent, NoTangent, ZeroTangent
+using ChainRulesCore: AbstractTangent, Tangent, NoTangent, ZeroTangent, AbstractThunk, unthunk
 
 include("value_shape.jl")
 include("value_accessor.jl")
diff --git a/src/named_tuple_shape.jl b/src/named_tuple_shape.jl
index 7be84215..dee79687 100644
--- a/src/named_tuple_shape.jl
+++ b/src/named_tuple_shape.jl
@@ -345,7 +345,7 @@ end
 # Zygote will currently ignore this, see Zygote.jl issue #811:
 function ChainRulesCore.rrule(::typeof(Base.getindex), x::ShapedAsNT)
     result = x[]
-    shapedasnt_getindex_pullback(ΔΩ::NamedTuple) = (NoTangent(), _shaped_nt_ΔΩ(ΔΩ, result, x))
+    shapedasnt_getindex_pullback(ΔΩ) = (NoTangent(), _shaped_nt_ΔΩ(unthunk(ΔΩ), result, x))
     return result, shapedasnt_getindex_pullback
 end
 #
@@ -361,17 +361,18 @@ end
 
 function ChainRulesCore.rrule(::Type{ShapedAsNT}, A::AbstractVector{<:Real}, vs::NamedTupleShape{names}) where names
     result = ShapedAsNT(A, vs)
-    function shapedasnt_pullback(ΔΩ::Union{ShapedAsNT{<:NamedTuple{names}},NamedTuple{names}})
+    function shapedasnt_pullback_impl(ΔΩ::Union{ShapedAsNT{<:NamedTuple{names}},NamedTuple{names}})
         (NoTangent(), unshaped(ΔΩ, gradient_shape(vs)), nothing)
     end
-    function shapedasnt_pullback(ΔΩ_c::Tangent{Any,<:NamedTuple{names}})
+    function shapedasnt_pullback_impl(ΔΩ_c::Tangent{Any,<:NamedTuple{names}})
         ΔΩ = NamedTuple{names}((ΔΩ_c...,))
-        shapedasnt_pullback(ΔΩ)
+        shapedasnt_pullback_impl(ΔΩ)
     end
-    function shapedasnt_pullback(ΔΩ_c::Tangent{Any,<:NamedTuple{(:__internal_data, :__internal_valshape)}})
+    function shapedasnt_pullback_impl(ΔΩ_c::Tangent{Any,<:NamedTuple{(:__internal_data, :__internal_valshape)}})
         @assert ΔΩ_c.__internal_valshape == NoTangent() || ΔΩ_c.__internal_valshape == ZeroTangent()
         (NoTangent(), ΔΩ_c.__internal_data, nothing)
     end
+    shapedasnt_pullback(ΔΩ) = shapedasnt_pullback_impl(unthunk(ΔΩ))
     return result, shapedasnt_pullback
 end
 