From ae339b693d8d48cae2def49a1a7b620eb6591817 Mon Sep 17 00:00:00 2001
From: Frames White <oxinabox@ucc.asn.au>
Date: Fri, 9 Feb 2024 18:50:35 +0800
Subject: [PATCH 1/3] Use ProjectTo in multiarg +

---
 src/rulesets/Base/arraymath.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl
index 078bb602a..89ed9fe8a 100644
--- a/src/rulesets/Base/arraymath.jl
+++ b/src/rulesets/Base/arraymath.jl
@@ -443,10 +443,10 @@ frule((_, ΔAs...), ::typeof(+), As::AbstractArray...) = +(As...), +(ΔAs...)
 
 function rrule(::typeof(+), arrs::AbstractArray...)
     y = +(arrs...)
-    arr_axs = map(axes, arrs)
+    projs = map(ProjectTo, arrs)
     function add_pullback(dy_raw)
-        dy = unthunk(dy_raw)  # reshape will otherwise unthunk N times
-        return (NoTangent(), map(ax -> reshape(dy, ax), arr_axs)...)
+        dy = unthunk(dy_raw)  # projs will otherwise unthunk N times
+        return (NoTangent(), map(proj -> proj(dy), projs)...)
     end
     return y, add_pullback
 end

From abf9e3757eb3ead01a519fdeeb6b32eb285d87f3 Mon Sep 17 00:00:00 2001
From: Frames White <me@oxinabox.net>
Date: Mon, 12 Feb 2024 14:33:36 +0800
Subject: [PATCH 2/3] Add tricky cases

---
 test/rulesets/Base/arraymath.jl | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl
index 2682f3b8a..7c76aa17a 100644
--- a/test/rulesets/Base/arraymath.jl
+++ b/test/rulesets/Base/arraymath.jl
@@ -216,5 +216,7 @@
         # rev
         @gpu test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4))
         @gpu test_rrule(+, randn(3), randn(3,1), randn(3,1,1))
+        test_rrule(+, randn(3,3), Diagonal(randn(3)), randn(3,3,1))
+        test_rrule(+, randn(3,3), Diagonal(randn(3)), Symmetric(randn(3,3)))
     end
 end

From 5f078bfee88cc6c3da41fe91e939f2f7e8b68a36 Mon Sep 17 00:00:00 2001
From: Frames White <oxinabox@ucc.asn.au>
Date: Mon, 12 Feb 2024 14:36:01 +0800
Subject: [PATCH 3/3] formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 test/rulesets/Base/arraymath.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl
index 7c76aa17a..0a1416444 100644
--- a/test/rulesets/Base/arraymath.jl
+++ b/test/rulesets/Base/arraymath.jl
@@ -215,8 +215,8 @@
         @gpu test_frule(+, randn(2), randn(2), randn(2))
         # rev
         @gpu test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4))
-        @gpu test_rrule(+, randn(3), randn(3,1), randn(3,1,1))
-        test_rrule(+, randn(3,3), Diagonal(randn(3)), randn(3,3,1))
-        test_rrule(+, randn(3,3), Diagonal(randn(3)), Symmetric(randn(3,3)))
+        @gpu test_rrule(+, randn(3), randn(3, 1), randn(3, 1, 1))
+        test_rrule(+, randn(3, 3), Diagonal(randn(3)), randn(3, 3, 1))
+        test_rrule(+, randn(3, 3), Diagonal(randn(3)), Symmetric(randn(3, 3)))
     end
 end