@@ -102,61 +102,102 @@ end
102
102
end
103
103
end
104
104
end
105
- @testset " svd" begin
106
- for n in [4 , 6 , 10 ], m in [3 , 5 , 9 ]
107
- @testset " ($n x $m ) svd" begin
108
- X = randn (n, m)
109
- test_rrule (svd, X; atol= 1e-6 , rtol= 1e-6 )
110
- end
111
- end
112
105
113
- for n in [4 , 6 , 10 ], m in [3 , 5 , 10 ]
114
- @testset " ($n x $m ) getproperty" begin
115
- X = randn (n, m)
116
- F = svd (X)
117
- rand_adj = adjoint (rand (reverse (size (F. V))... ))
106
+ @testset " singular value decomposition" begin
107
+ @testset " svd" begin
108
+ for n in [4 , 6 , 10 ], m in [3 , 5 , 9 ]
109
+ @testset " ($n x $m ) svd" begin
110
+ X = randn (n, m)
111
+ test_rrule (svd, X; atol= 1e-6 , rtol= 1e-6 )
112
+ end
113
+ end
118
114
119
- test_rrule (getproperty, F, :U ; check_inferred= false )
120
- test_rrule (getproperty, F, :S ; check_inferred= false )
121
- test_rrule (getproperty, F, :Vt ; check_inferred= false )
122
- test_rrule (getproperty, F, :V ; check_inferred= false , output_tangent= rand_adj)
115
+ for n in [4 , 6 , 10 ], m in [3 , 5 , 10 ]
116
+ @testset " ($n x $m ) getproperty" begin
117
+ X = randn (n, m)
118
+ F = svd (X)
119
+ rand_adj = adjoint (rand (reverse (size (F. V))... ))
120
+
121
+ test_rrule (getproperty, F, :U ; check_inferred= false )
122
+ test_rrule (getproperty, F, :S ; check_inferred= false )
123
+ test_rrule (getproperty, F, :Vt ; check_inferred= false )
124
+ test_rrule (
125
+ getproperty, F, :V ; check_inferred= false , output_tangent= rand_adj
126
+ )
127
+ end
123
128
end
124
- end
125
129
126
- @testset " Thunked inputs" begin
127
- X = randn (4 , 3 )
128
- F, dX_pullback = rrule (svd, X)
129
- for p in [:U , :S , :V , :Vt ]
130
- Y, dF_pullback = rrule (getproperty, F, p)
131
- Ȳ = randn (size (Y)... )
130
+ @testset " Thunked inputs" begin
131
+ X = randn (4 , 3 )
132
+ F, dX_pullback = rrule (svd, X)
133
+ for p in [:U , :S , :V , :Vt ]
134
+ Y, dF_pullback = rrule (getproperty, F, p)
135
+ Ȳ = randn (size (Y)... )
136
+
137
+ _, dF_unthunked, _ = dF_pullback (Ȳ)
132
138
133
- _, dF_unthunked, _ = dF_pullback (Ȳ)
139
+ # helper to let us check how things are stored.
140
+ p_access = p == :V ? :Vt : p
141
+ backing_field (c, p) = getproperty (ChainRulesCore. backing (c), p_access)
142
+ @assert ! (backing_field (dF_unthunked, p) isa AbstractThunk)
134
143
135
- # helper to let us check how things are stored.
136
- p_access = p == :V ? :Vt : p
137
- backing_field (c, p) = getproperty (ChainRulesCore. backing (c), p_access)
138
- @assert ! (backing_field (dF_unthunked, p) isa AbstractThunk)
144
+ dF_thunked = map (f -> Thunk (() -> f), dF_unthunked)
145
+ @assert backing_field (dF_thunked, p) isa AbstractThunk
146
+
147
+ dself_thunked, dX_thunked = dX_pullback (dF_thunked)
148
+ dself_unthunked, dX_unthunked = dX_pullback (dF_unthunked)
149
+ @test dself_thunked == dself_unthunked
150
+ @test dX_thunked == dX_unthunked
151
+ end
152
+ end
139
153
140
- dF_thunked = map (f-> Thunk (()-> f), dF_unthunked)
141
- @assert backing_field (dF_thunked, p) isa AbstractThunk
154
+ @testset " Helper functions" begin
155
+ X = randn (10 , 10 )
156
+ Y = randn (10 , 10 )
157
+ @test ChainRules. _mulsubtrans!! (copy (X), Y) ≈ Y .* (X - X' )
158
+ @test ChainRules. _eyesubx! (copy (X)) ≈ I - X
142
159
143
- dself_thunked, dX_thunked = dX_pullback (dF_thunked )
144
- dself_unthunked, dX_unthunked = dX_pullback (dF_unthunked )
145
- @test dself_thunked == dself_unthunked
146
- @test dX_thunked == dX_unthunked
160
+ Z = randn (Float32, 10 , 10 )
161
+ result = ChainRules . _mulsubtrans!! ( copy (Z), Y )
162
+ @test result ≈ Y .* (Z - Z ' )
163
+ @test eltype (result) == Float64
147
164
end
148
165
end
149
166
150
- @testset " Helper functions" begin
151
- X = randn (10 , 10 )
152
- Y = randn (10 , 10 )
153
- @test ChainRules. _mulsubtrans!! (copy (X), Y) ≈ Y .* (X - X' )
154
- @test ChainRules. _eyesubx! (copy (X)) ≈ I - X
167
+ @testset " svdvals" begin
168
+ for n in [4 , 6 , 10 ]
169
+ for m in [3 , 5 , 9 ]
170
+ @testset " ($n x $m ) svdvals" begin
171
+ X = randn (n, m)
172
+ test_rrule (svdvals, X; atol= 1e-6 , rtol= 1e-6 )
173
+ end
174
+ end
175
+
176
+ @testset " rrule for svdvals(::$SymHerm {$T }) ($n x $n , uplo=$uplo )" for SymHerm in
177
+ (
178
+ Symmetric, Hermitian
179
+ ),
180
+ T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)),
181
+ uplo in (:L , :U )
182
+
183
+ A, ΔS = randn (T, n, n), randn (n)
184
+ symA = SymHerm (A, uplo)
155
185
156
- Z = randn (Float32, 10 , 10 )
157
- result = ChainRules. _mulsubtrans!! (copy (Z), Y)
158
- @test result ≈ Y .* (Z - Z' )
159
- @test eltype (result) == Float64
186
+ S = svdvals (symA)
187
+ S_ad, back = @inferred rrule (svdvals, symA)
188
+ @test S_ad ≈ S # inexact because rrule uses svd not svdvals
189
+ ∂self, ∂symA = @inferred back (ΔS)
190
+ @test ∂self === NoTangent ()
191
+ @test ∂symA isa typeof (symA)
192
+ @test ∂symA. uplo == symA. uplo
193
+
194
+ # pull the cotangent back to A to test against finite differences
195
+ ∂A = rrule (SymHerm, A, uplo)[2 ](∂symA)[2 ]
196
+ @test ∂A ≈ j′vp (_fdm, A -> svdvals (SymHerm (A, uplo)), ΔS, A)[1 ]
197
+
198
+ @test @inferred (back (ZeroTangent ())) == (NoTangent (), ZeroTangent ())
199
+ end
200
+ end
160
201
end
161
202
end
162
203
0 commit comments