diff --git a/src/named_tuple_shape.jl b/src/named_tuple_shape.jl index 701ab3e5..072d8d94 100644 --- a/src/named_tuple_shape.jl +++ b/src/named_tuple_shape.jl @@ -752,6 +752,20 @@ end const _AnySNTArray{names} = ShapedAsNTArray{<:Union{NamedTuple{names},ShapedAsNT{names}}} +# For accumulation during automatic differentiation: +function Base.:(+)(A::_AnySNTArray{names}, B::_AnySNTArray{names}) where names + @argcheck elshape(A) == elshape(B) + ShapedAsNTArray(_data(A) + _data(B), elshape(A)) +end + +# For accumulation during automatic differentiation: +function ChainRulesCore.add!!(A::_AnySNTArray{names}, B::_AnySNTArray{names}) where names + @argcheck elshape(A) == elshape(B) + ChainRulesCore.add!!(_data(A), _data(B)) + return A +end + + function ChainRulesCore.Tangent(X::T, unshaped_dX::AbstractArray{<:AbstractVector{<:Real}}) where {T<:ShapedAsNTArray} vs = elshape(X) gs = gradient_shape(vs) @@ -789,6 +803,11 @@ function _write_snta_col!(data::AbstractArray{<:AbstractVector{<:Real}}, va::Val B = view.(data, Ref(va)) B .= A end +function _write_snta_col!(data::ArrayOfSimilarVectors{<:Real}, va::ValueAccessor, A::_ZeroLike) + flat_data = flatview(data) + idxs = view_idxs(axes(flat_data, 1), va) + fill!(view(flat_data, idxs, :), zero(eltype(flat_data))) +end _write_snta_col!(data::AbstractArray{<:AbstractVector{<:Real}}, va::ConstAccessor, A) = nothing function _tablecols_tangent(X::_AnySNTArray, dY::NamedTuple{names}) where names