diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index f80a37048..d947567e5 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -261,4 +261,17 @@ end Val(3); check_inferred=(VERSION >= v"1.7"), ) + + # eachslice: Make sure pulling back an array of thunks unthunks them and does not return all zeros. + x = ones(Float32, 3) + Δ = ones(Float32, 1) + _, norm_back = ChainRules.rrule(norm, x) + dx = norm_back(Δ)[2] + @test dx isa AbstractThunk + + x = ones(Float32, 3, 1) + _, eachcol_back = ChainRules.rrule(eachcol, x) + Δ2 = [dx] + dx2 = eachcol_back(Δ2)[2] + @test all(dx2 .≉ 0f0) end