From 675fecac6d1c90e083ab1173374fce3b2e8c6e13 Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas <accounts@chrisrackauckas.com>
Date: Wed, 22 Nov 2023 10:46:37 -0500
Subject: [PATCH 01/10] Add Tridiagonal construction rule

---
 src/rulesets/LinearAlgebra/structured.jl | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl
index 245335774..705c372da 100644
--- a/src/rulesets/LinearAlgebra/structured.jl
+++ b/src/rulesets/LinearAlgebra/structured.jl
@@ -268,3 +268,16 @@ function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular})
     end
     return y, logdet_pullback
 end
+
+#####
+##### Tridiagonal
+#####
+
+function rrule(::Type{Tridiagonal}, dl, d, du)
+    y = Tridiagonal(dl, d, du)
+    @views function ∇Tridiagonal(∂y)
+        return (NoTangent(), diag(∂y[2:end, 1:(end - 1)]), diag(∂y),
+            diag(∂y[1:(end - 1), 2:end]))
+    end
+    return y, ∇Tridiagonal
+end

From 3459acb1b02f31a663299b8f28f103ef62d2c110 Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas <accounts@chrisrackauckas.com>
Date: Wed, 22 Nov 2023 10:48:58 -0500
Subject: [PATCH 02/10] Update structured.jl

---
 test/rulesets/LinearAlgebra/structured.jl | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl
index 1b83cc394..46f089dbf 100644
--- a/test/rulesets/LinearAlgebra/structured.jl
+++ b/test/rulesets/LinearAlgebra/structured.jl
@@ -161,4 +161,9 @@
             end
         end
     end
+
+    @testset "Tridiagonal" begin
+        res, pb = rrule(Tridiagonal, [1, 4], [2, 3, 4], [5, 3])
+        @test pb(10*res) == (NoTangent(), [10, 40], [20, 30, 40], [50, 30])
+    end
 end

From a275f1ff33a6ef20613803e6475750e5d93fbdd2 Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas <accounts@chrisrackauckas.com>
Date: Wed, 22 Nov 2023 10:49:18 -0500
Subject: [PATCH 03/10] Update src/rulesets/LinearAlgebra/structured.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/rulesets/LinearAlgebra/structured.jl | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl
index 705c372da..d2ecdaba9 100644
--- a/src/rulesets/LinearAlgebra/structured.jl
+++ b/src/rulesets/LinearAlgebra/structured.jl
@@ -276,8 +276,12 @@ end
 function rrule(::Type{Tridiagonal}, dl, d, du)
     y = Tridiagonal(dl, d, du)
     @views function ∇Tridiagonal(∂y)
-        return (NoTangent(), diag(∂y[2:end, 1:(end - 1)]), diag(∂y),
-            diag(∂y[1:(end - 1), 2:end]))
+        return (
+            NoTangent(),
+            diag(∂y[2:end, 1:(end - 1)]),
+            diag(∂y),
+            diag(∂y[1:(end - 1), 2:end]),
+        )
     end
     return y, ∇Tridiagonal
 end

From e823ed8c92d2054df5ab400290dd431455abd637 Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas <accounts@chrisrackauckas.com>
Date: Wed, 22 Nov 2023 10:53:30 -0500
Subject: [PATCH 04/10] Update test/rulesets/LinearAlgebra/structured.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 test/rulesets/LinearAlgebra/structured.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl
index 46f089dbf..b64152204 100644
--- a/test/rulesets/LinearAlgebra/structured.jl
+++ b/test/rulesets/LinearAlgebra/structured.jl
@@ -164,6 +164,6 @@
 
     @testset "Tridiagonal" begin
         res, pb = rrule(Tridiagonal, [1, 4], [2, 3, 4], [5, 3])
-        @test pb(10*res) == (NoTangent(), [10, 40], [20, 30, 40], [50, 30])
+        @test pb(10 * res) == (NoTangent(), [10, 40], [20, 30, 40], [50, 30])
     end
 end

From 9ba86d8052b532e8e021f446403a63d4d0959612 Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas <accounts@chrisrackauckas.com>
Date: Thu, 23 Nov 2023 07:41:01 -0500
Subject: [PATCH 05/10] Update test/rulesets/LinearAlgebra/structured.jl

Co-authored-by: Frames White <oxinabox@ucc.asn.au>
---
 test/rulesets/LinearAlgebra/structured.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl
index b64152204..3c0c7f92f 100644
--- a/test/rulesets/LinearAlgebra/structured.jl
+++ b/test/rulesets/LinearAlgebra/structured.jl
@@ -163,7 +163,7 @@
     end
 
     @testset "Tridiagonal" begin
-        res, pb = rrule(Tridiagonal, [1, 4], [2, 3, 4], [5, 3])
+        test_rrule(Tridiagonal, [1.0, 4.0], [2.0, 3.0, 4.0], [5.0, 3.0])
         @test pb(10 * res) == (NoTangent(), [10, 40], [20, 30, 40], [50, 30])
     end
 end

From ee952df7f1092ec632fbcbfefacd2e2b24961f19 Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas <accounts@chrisrackauckas.com>
Date: Thu, 23 Nov 2023 07:43:34 -0500
Subject: [PATCH 06/10] Update src/rulesets/LinearAlgebra/structured.jl

---
 src/rulesets/LinearAlgebra/structured.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl
index d2ecdaba9..f2d21a274 100644
--- a/src/rulesets/LinearAlgebra/structured.jl
+++ b/src/rulesets/LinearAlgebra/structured.jl
@@ -278,9 +278,9 @@ function rrule(::Type{Tridiagonal}, dl, d, du)
     @views function ∇Tridiagonal(∂y)
         return (
             NoTangent(),
-            diag(∂y[2:end, 1:(end - 1)]),
+            diag(∂y, -1),
             diag(∂y),
-            diag(∂y[1:(end - 1), 2:end]),
+            diag(∂y, 1),
         )
     end
     return y, ∇Tridiagonal

From 58957361ecfcf78488c831729cb66b26a473b893 Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas <accounts@chrisrackauckas.com>
Date: Thu, 23 Nov 2023 09:34:05 -0500
Subject: [PATCH 07/10] Update src/rulesets/LinearAlgebra/structured.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/rulesets/LinearAlgebra/structured.jl | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl
index f2d21a274..fa96dd126 100644
--- a/src/rulesets/LinearAlgebra/structured.jl
+++ b/src/rulesets/LinearAlgebra/structured.jl
@@ -276,12 +276,7 @@ end
 function rrule(::Type{Tridiagonal}, dl, d, du)
     y = Tridiagonal(dl, d, du)
     @views function ∇Tridiagonal(∂y)
-        return (
-            NoTangent(),
-            diag(∂y, -1),
-            diag(∂y),
-            diag(∂y, 1),
-        )
+        return (NoTangent(), diag(∂y, -1), diag(∂y), diag(∂y, 1))
     end
     return y, ∇Tridiagonal
 end

From dfbd363302a7bbeceb9c59078984d3c36b131d5d Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas <accounts@chrisrackauckas.com>
Date: Thu, 23 Nov 2023 09:34:14 -0500
Subject: [PATCH 08/10] Update test/rulesets/LinearAlgebra/structured.jl

Co-authored-by: Frames White <oxinabox@ucc.asn.au>
---
 test/rulesets/LinearAlgebra/structured.jl | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl
index 3c0c7f92f..c46460f46 100644
--- a/test/rulesets/LinearAlgebra/structured.jl
+++ b/test/rulesets/LinearAlgebra/structured.jl
@@ -164,6 +164,5 @@
 
     @testset "Tridiagonal" begin
         test_rrule(Tridiagonal, [1.0, 4.0], [2.0, 3.0, 4.0], [5.0, 3.0])
-        @test pb(10 * res) == (NoTangent(), [10, 40], [20, 30, 40], [50, 30])
     end
 end

From 938623b9496c0a3f38a9d1479ed8a1c477a86620 Mon Sep 17 00:00:00 2001
From: Frames White <me@oxinabox.net>
Date: Mon, 20 May 2024 21:48:27 +0800
Subject: [PATCH 09/10] unthunk input to Tridiagonal_pullback

---
 src/rulesets/LinearAlgebra/structured.jl | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl
index fa96dd126..722e37791 100644
--- a/src/rulesets/LinearAlgebra/structured.jl
+++ b/src/rulesets/LinearAlgebra/structured.jl
@@ -275,8 +275,9 @@ end
 
 function rrule(::Type{Tridiagonal}, dl, d, du)
     y = Tridiagonal(dl, d, du)
-    @views function ∇Tridiagonal(∂y)
+    @views function Tridiagonal_pullback(ȳ)
+        ∂y = unthunk(ȳ)
         return (NoTangent(), diag(∂y, -1), diag(∂y), diag(∂y, 1))
     end
-    return y, ∇Tridiagonal
+    return y, Tridiagonal_pullback
 end

From ce288b9c571f636b73fd8e128543c9cde979edb1 Mon Sep 17 00:00:00 2001
From: Frames White <me@oxinabox.net>
Date: Mon, 20 May 2024 21:50:34 +0800
Subject: [PATCH 10/10] remove unneded veiws macrod

---
 src/rulesets/LinearAlgebra/structured.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl
index 722e37791..e48416739 100644
--- a/src/rulesets/LinearAlgebra/structured.jl
+++ b/src/rulesets/LinearAlgebra/structured.jl
@@ -275,7 +275,7 @@ end
 
 function rrule(::Type{Tridiagonal}, dl, d, du)
     y = Tridiagonal(dl, d, du)
-    @views function Tridiagonal_pullback(ȳ)
+    function Tridiagonal_pullback(ȳ)
         ∂y = unthunk(ȳ)
         return (NoTangent(), diag(∂y, -1), diag(∂y), diag(∂y, 1))
     end