Skip to content

Commit

Permalink
Improve ShapedAsNTArray rrules
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Jun 8, 2022
1 parent fa114a0 commit 2b78ee6
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/named_tuple_shape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,20 @@ end
const _AnySNTArray{names} = ShapedAsNTArray{<:Union{NamedTuple{names},ShapedAsNT{names}}}


# For accumulation during automatic differentiation:
function Base.:(+)(A::_AnySNTArray{names}, B::_AnySNTArray{names}) where names
@argcheck elshape(A) == elshape(B)
ShapedAsNTArray(_data(A) + _data(B), elshape(A))
end

# For accumulation during automatic differentiation:
function ChainRulesCore.add!!(A::_AnySNTArray{names}, B::_AnySNTArray{names}) where names
@argcheck elshape(A) == elshape(B)
ChainRulesCore.add!!(_data(A), _data(B))
return A
end


function ChainRulesCore.Tangent(X::T, unshaped_dX::AbstractArray{<:AbstractVector{<:Real}}) where {T<:ShapedAsNTArray}
vs = elshape(X)
gs = gradient_shape(vs)
Expand Down Expand Up @@ -789,6 +803,11 @@ function _write_snta_col!(data::AbstractArray{<:AbstractVector{<:Real}}, va::Val
B = view.(data, Ref(va))
B .= A
end
function _write_snta_col!(data::ArrayOfSimilarVectors{<:Real}, va::ValueAccessor, A::_ZeroLike)
flat_data = flatview(data)
idxs = view_idxs(axes(flat_data, 1), va)
fill!(view(flat_data, idxs, :), zero(eltype(flat_data)))
end
_write_snta_col!(data::AbstractArray{<:AbstractVector{<:Real}}, va::ConstAccessor, A) = nothing

function _tablecols_tangent(X::_AnySNTArray, dY::NamedTuple{names}) where names
Expand Down

0 comments on commit 2b78ee6

Please sign in to comment.