Skip to content

Commit 0f6752f

Browse files
committed
Generalize rrule for svdvals
1 parent f9f0722 commit 0f6752f

File tree

4 files changed

+100
-87
lines changed

4 files changed

+100
-87
lines changed

src/rulesets/LinearAlgebra/factorization.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,22 @@ function svd_rev(USV::SVD, Ū, s̄, V̄)
276276
return Ā
277277
end
278278

279+
#####
280+
##### `svdvals`
281+
#####
282+
283+
function rrule(::typeof(svdvals), A::AbstractMatrix{<:Number})
284+
F = svd(A)
285+
U = F.U
286+
Vt = F.Vt
287+
project_A = ProjectTo(A)
288+
function svdvals_pullback(s̄)
289+
=isa AbstractZero ?: Diagonal(unthunk(s̄))
290+
(NoTangent(), project_A(U ** Vt))
291+
end
292+
return F.S, svdvals_pullback
293+
end
294+
279295
#####
280296
##### `eigen`
281297
#####

src/rulesets/LinearAlgebra/symmetric.jl

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -277,28 +277,6 @@ function _svd_eigvals_sign!(c, U, V)
277277
return c
278278
end
279279

280-
#####
281-
##### `svdvals`
282-
#####
283-
284-
# NOTE: rrule defined because `svdvals` calls mutating `svdvals!` internally.
285-
# can be removed when mutation is supported by reverse-mode AD packages
286-
function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix})
287-
λ, back = rrule(eigvals, A)
288-
S = abs.(λ)
289-
p = sortperm(S; rev=true)
290-
permute!(S, p)
291-
function svdvals_pullback(ΔS)
292-
∂λ = real.(ΔS)
293-
invpermute!(∂λ, p)
294-
∂λ .*= sign.(λ)
295-
_, ∂A = back(∂λ)
296-
return NoTangent(), unthunk(∂A)
297-
end
298-
svdvals_pullback(ΔS::AbstractZero) = (NoTangent(), ΔS)
299-
return S, svdvals_pullback
300-
end
301-
302280
#####
303281
##### matrix functions
304282
#####

test/rulesets/LinearAlgebra/factorization.jl

Lines changed: 84 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -102,61 +102,102 @@ end
102102
end
103103
end
104104
end
105-
@testset "svd" begin
106-
for n in [4, 6, 10], m in [3, 5, 9]
107-
@testset "($n x $m) svd" begin
108-
X = randn(n, m)
109-
test_rrule(svd, X; atol=1e-6, rtol=1e-6)
110-
end
111-
end
112105

113-
for n in [4, 6, 10], m in [3, 5, 10]
114-
@testset "($n x $m) getproperty" begin
115-
X = randn(n, m)
116-
F = svd(X)
117-
rand_adj = adjoint(rand(reverse(size(F.V))...))
106+
@testset "singular value decomposition" begin
107+
@testset "svd" begin
108+
for n in [4, 6, 10], m in [3, 5, 9]
109+
@testset "($n x $m) svd" begin
110+
X = randn(n, m)
111+
test_rrule(svd, X; atol=1e-6, rtol=1e-6)
112+
end
113+
end
118114

119-
test_rrule(getproperty, F, :U; check_inferred=false)
120-
test_rrule(getproperty, F, :S; check_inferred=false)
121-
test_rrule(getproperty, F, :Vt; check_inferred=false)
122-
test_rrule(getproperty, F, :V; check_inferred=false, output_tangent=rand_adj)
115+
for n in [4, 6, 10], m in [3, 5, 10]
116+
@testset "($n x $m) getproperty" begin
117+
X = randn(n, m)
118+
F = svd(X)
119+
rand_adj = adjoint(rand(reverse(size(F.V))...))
120+
121+
test_rrule(getproperty, F, :U; check_inferred=false)
122+
test_rrule(getproperty, F, :S; check_inferred=false)
123+
test_rrule(getproperty, F, :Vt; check_inferred=false)
124+
test_rrule(
125+
getproperty, F, :V; check_inferred=false, output_tangent=rand_adj
126+
)
127+
end
123128
end
124-
end
125129

126-
@testset "Thunked inputs" begin
127-
X = randn(4, 3)
128-
F, dX_pullback = rrule(svd, X)
129-
for p in [:U, :S, :V, :Vt]
130-
Y, dF_pullback = rrule(getproperty, F, p)
131-
= randn(size(Y)...)
130+
@testset "Thunked inputs" begin
131+
X = randn(4, 3)
132+
F, dX_pullback = rrule(svd, X)
133+
for p in [:U, :S, :V, :Vt]
134+
Y, dF_pullback = rrule(getproperty, F, p)
135+
= randn(size(Y)...)
136+
137+
_, dF_unthunked, _ = dF_pullback(Ȳ)
132138

133-
_, dF_unthunked, _ = dF_pullback(Ȳ)
139+
# helper to let us check how things are stored.
140+
p_access = p == :V ? :Vt : p
141+
backing_field(c, p) = getproperty(ChainRulesCore.backing(c), p_access)
142+
@assert !(backing_field(dF_unthunked, p) isa AbstractThunk)
134143

135-
# helper to let us check how things are stored.
136-
p_access = p == :V ? :Vt : p
137-
backing_field(c, p) = getproperty(ChainRulesCore.backing(c), p_access)
138-
@assert !(backing_field(dF_unthunked, p) isa AbstractThunk)
144+
dF_thunked = map(f -> Thunk(() -> f), dF_unthunked)
145+
@assert backing_field(dF_thunked, p) isa AbstractThunk
146+
147+
dself_thunked, dX_thunked = dX_pullback(dF_thunked)
148+
dself_unthunked, dX_unthunked = dX_pullback(dF_unthunked)
149+
@test dself_thunked == dself_unthunked
150+
@test dX_thunked == dX_unthunked
151+
end
152+
end
139153

140-
dF_thunked = map(f->Thunk(()->f), dF_unthunked)
141-
@assert backing_field(dF_thunked, p) isa AbstractThunk
154+
@testset "Helper functions" begin
155+
X = randn(10, 10)
156+
Y = randn(10, 10)
157+
@test ChainRules._mulsubtrans!!(copy(X), Y) Y .* (X - X')
158+
@test ChainRules._eyesubx!(copy(X)) I - X
142159

143-
dself_thunked, dX_thunked = dX_pullback(dF_thunked)
144-
dself_unthunked, dX_unthunked = dX_pullback(dF_unthunked)
145-
@test dself_thunked == dself_unthunked
146-
@test dX_thunked == dX_unthunked
160+
Z = randn(Float32, 10, 10)
161+
result = ChainRules._mulsubtrans!!(copy(Z), Y)
162+
@test result Y .* (Z - Z')
163+
@test eltype(result) == Float64
147164
end
148165
end
149166

150-
@testset "Helper functions" begin
151-
X = randn(10, 10)
152-
Y = randn(10, 10)
153-
@test ChainRules._mulsubtrans!!(copy(X), Y) Y .* (X - X')
154-
@test ChainRules._eyesubx!(copy(X)) I - X
167+
@testset "svdvals" begin
168+
for n in [4, 6, 10]
169+
for m in [3, 5, 9]
170+
@testset "($n x $m) svdvals" begin
171+
X = randn(n, m)
172+
test_rrule(svdvals, X; atol=1e-6, rtol=1e-6)
173+
end
174+
end
175+
176+
@testset "rrule for svdvals(::$SymHerm{$T}) ($n x $n, uplo=$uplo)" for SymHerm in
177+
(
178+
Symmetric, Hermitian
179+
),
180+
T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)),
181+
uplo in (:L, :U)
182+
183+
A, ΔS = randn(T, n, n), randn(n)
184+
symA = SymHerm(A, uplo)
155185

156-
Z = randn(Float32, 10, 10)
157-
result = ChainRules._mulsubtrans!!(copy(Z), Y)
158-
@test result Y .* (Z - Z')
159-
@test eltype(result) == Float64
186+
S = svdvals(symA)
187+
S_ad, back = @inferred rrule(svdvals, symA)
188+
@test S_ad S # inexact because rrule uses svd not svdvals
189+
∂self, ∂symA = @inferred back(ΔS)
190+
@test ∂self === NoTangent()
191+
@test ∂symA isa typeof(symA)
192+
@test ∂symA.uplo == symA.uplo
193+
194+
# pull the cotangent back to A to test against finite differences
195+
∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2]
196+
@test ∂A j′vp(_fdm, A -> svdvals(SymHerm(A, uplo)), ΔS, A)[1]
197+
198+
@test @inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent())
199+
end
200+
end
160201
end
161202
end
162203

test/rulesets/LinearAlgebra/symmetric.jl

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -275,28 +275,6 @@
275275
@test @maybe_inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent())
276276
@test @maybe_inferred(back(CT())) == (NoTangent(), ZeroTangent())
277277
end
278-
279-
@testset "rrule for svdvals(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian),
280-
T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)),
281-
uplo in (:L, :U)
282-
283-
A, ΔS = randn(T, n, n), randn(n)
284-
symA = SymHerm(A, uplo)
285-
286-
S = svdvals(symA)
287-
S_ad, back = @maybe_inferred rrule(svdvals, symA)
288-
@test S_ad S # inexact because rrule uses svd not svdvals
289-
∂self, ∂symA = @maybe_inferred back(ΔS)
290-
@test ∂self === NoTangent()
291-
@test ∂symA isa typeof(symA)
292-
@test ∂symA.uplo == symA.uplo
293-
294-
# pull the cotangent back to A to test against finite differences
295-
∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2]
296-
@test ∂A j′vp(_fdm, A -> svdvals(SymHerm(A, uplo)), ΔS, A)[1]
297-
298-
@test @maybe_inferred(back(ZeroTangent())) == (NoTangent(), ZeroTangent())
299-
end
300278
end
301279

302280
@testset "Symmetric/Hermitian matrix functions" begin

0 commit comments

Comments
 (0)