-
Notifications
You must be signed in to change notification settings - Fork 5
[WIP] Mooncake forward rules #103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
kshyatt
wants to merge
3
commits into
main
Choose a base branch
from
ksh/mooncake_fwd
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
+575
−184
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.
diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
index da115a9..a807826 100644
--- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
+++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
@@ -26,15 +26,16 @@ function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDu
return CoDual(Ac, dAc), copy_input_pb
end
# two-argument factorizations like LQ, QR, EIG
-for (f!, f, pb, pf, adj) in ((qr_full!, qr_full, qr_pullback!, qr_pushforward!, :dqr_adjoint),
- (qr_compact!, qr_compact, qr_pullback!, qr_pushforward!, :dqr_adjoint),
- (lq_full!, lq_full, lq_pullback!, lq_pushforward!, :dlq_adjoint),
- (lq_compact!, lq_compact, lq_pullback!, lq_pushforward!, :dlq_adjoint),
- (eig_full!, eig_full, eig_pullback!, eig_pushforward!, :deig_adjoint),
- (eigh_full!, eigh_full, eigh_pullback!, eigh_pushforward!, :deigh_adjoint),
- (left_polar!, left_polar, left_polar_pullback!, left_polar_pushforward!, :dleft_polar_adjoint),
- (right_polar!, right_polar, right_polar_pullback!, right_polar_pushforward!, :dright_polar_adjoint),
- )
+for (f!, f, pb, pf, adj) in (
+ (qr_full!, qr_full, qr_pullback!, qr_pushforward!, :dqr_adjoint),
+ (qr_compact!, qr_compact, qr_pullback!, qr_pushforward!, :dqr_adjoint),
+ (lq_full!, lq_full, lq_pullback!, lq_pushforward!, :dlq_adjoint),
+ (lq_compact!, lq_compact, lq_pullback!, lq_pushforward!, :dlq_adjoint),
+ (eig_full!, eig_full, eig_pullback!, eig_pushforward!, :deig_adjoint),
+ (eigh_full!, eigh_full, eigh_pullback!, eigh_pushforward!, :deigh_adjoint),
+ (left_polar!, left_polar, left_polar_pullback!, left_polar_pushforward!, :dleft_polar_adjoint),
+ (right_polar!, right_polar, right_polar_pullback!, right_polar_pushforward!, :dright_polar_adjoint),
+ )
@eval begin
@is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm}
@@ -82,21 +83,21 @@ for (f!, f, pb, pf, adj) in ((qr_full!, qr_full, qr_pullback!, qr_pushforward!,
end
function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, args_dargs::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm})
A, dA = arrayify(A_dA)
- args = Mooncake.primal(args_dargs)
- args = $f!(A, args, Mooncake.primal(alg_dalg))
+ args = Mooncake.primal(args_dargs)
+ args = $f!(A, args, Mooncake.primal(alg_dalg))
dargs = Mooncake.tangent(args_dargs)
- arg1, darg1 = arrayify(args[1], dargs[1])
- arg2, darg2 = arrayify(args[2], dargs[2])
+ arg1, darg1 = arrayify(args[1], dargs[1])
+ arg2, darg2 = arrayify(args[2], dargs[2])
darg1, darg2 = $pf(dA, A, (arg1, arg2), (darg1, darg2))
zero!(dA)
return args_dargs
end
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm})
- A, dA = arrayify(A_dA)
- args = $f(A, Mooncake.primal(alg_dalg))
- args_dargs = Mooncake.zero_dual(args)
- arg1, arg2 = args
- dargs = Mooncake.tangent(args_dargs)
+ A, dA = arrayify(A_dA)
+ args = $f(A, Mooncake.primal(alg_dalg))
+ args_dargs = Mooncake.zero_dual(args)
+ arg1, arg2 = args
+ dargs = Mooncake.tangent(args_dargs)
arg1, darg1 = arrayify(arg1, dargs[1])
arg2, darg2 = arrayify(arg2, dargs[2])
$pf(dA, A, (arg1, arg2), (darg1, darg2))
@@ -105,9 +106,10 @@ for (f!, f, pb, pf, adj) in ((qr_full!, qr_full, qr_pullback!, qr_pushforward!,
end
end
-for (f!, f, pb, pf, adj) in ((qr_null!, qr_null, qr_null_pullback!, qr_null_pushforward!, :dqr_null_adjoint),
- (lq_null!, lq_null, lq_null_pullback!, lq_null_pushforward!, :dlq_null_adjoint),
- )
+for (f!, f, pb, pf, adj) in (
+ (qr_null!, qr_null, qr_null_pullback!, qr_null_pushforward!, :dqr_null_adjoint),
+ (lq_null!, lq_null, lq_null_pullback!, lq_null_pushforward!, :dlq_null_adjoint),
+ )
#forward mode not implemented yet
@eval begin
@is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
@@ -140,18 +142,18 @@ for (f!, f, pb, pf, adj) in ((qr_null!, qr_null, qr_null_pullback!, qr_null_push
return output_codual, $adj
end
function Mooncake.frule!!(f_df::Dual{typeof($f!)}, A_dA::Dual, arg_darg::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm})
- A, dA = arrayify(A_dA)
- Ac = MatrixAlgebraKit.copy_input($f, A)
+ A, dA = arrayify(A_dA)
+ Ac = MatrixAlgebraKit.copy_input($f, A)
arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg))
- arg = $f!(A, arg, Mooncake.primal(alg_dalg))
+ arg = $f!(A, arg, Mooncake.primal(alg_dalg))
$pf(dA, Ac, arg, darg)
zero!(dA)
return arg_darg
end
function Mooncake.frule!!(f_df::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm})
- A, dA = arrayify(A_dA)
- arg = $f(A, Mooncake.primal(alg_dalg))
- darg = Mooncake.zero_tangent(arg)
+ A, dA = arrayify(A_dA)
+ arg = $f(A, Mooncake.primal(alg_dalg))
+ darg = Mooncake.zero_tangent(arg)
$pf(dA, A, arg, darg)
return Dual(arg, darg)
end
@@ -210,7 +212,7 @@ for (f!, f, f_full, pb, pf, adj) in (
# compute primal
A, dA = arrayify(A_dA)
fullD, V = $f_full(A, Mooncake.primal(alg_dalg))
- D_dD = Mooncake.zero_dual(diagview(fullD))
+ D_dD = Mooncake.zero_dual(diagview(fullD))
D, dD = arrayify(D_dD)
$pf(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing))
return D_dD
@@ -249,8 +251,8 @@ for (f, pb, pf, adj) in (
end
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual)
# compute primal
- A, dA = arrayify(A_dA)
- alg = Mooncake.primal(alg_dalg)
+ A, dA = arrayify(A_dA)
+ alg = Mooncake.primal(alg_dalg)
output = $f(A, alg)
output_dual = Mooncake.zero_dual(output)
dD_ = Mooncake.tangent(output_dual)[1]
@@ -335,26 +337,26 @@ for (f!, f) in (
end
function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual)
# compute primal
- USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
- dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
- A, dA = arrayify(A_dA)
+ USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
+ dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
+ A, dA = arrayify(A_dA)
$f!(A, USVᴴ, Mooncake.primal(alg_dalg))
# update tangents
- U_, S_, Vᴴ_ = USVᴴ
+ U_, S_, Vᴴ_ = USVᴴ
dU_, dS_, dVᴴ_ = dUSVᴴ
- U, dU = arrayify(U_, dU_)
- S, dS = arrayify(S_, dS_)
- Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_)
- minmn = min(size(A)...)
+ U, dU = arrayify(U_, dU_)
+ S, dS = arrayify(S_, dS_)
+ Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_)
+ minmn = min(size(A)...)
if $(f == svd_compact!) # compact
svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
else # full
- vU = view(U, :, 1:minmn)
- vS = view(S, 1:minmn, 1:minmn)
- vVᴴ = view(Vᴴ, 1:minmn, :)
- vdU = view(dU, :, 1:minmn)
- vdS = view(dS, 1:minmn, 1:minmn)
- vdVᴴ = view(dVᴴ, 1:minmn, :)
+ vU = view(U, :, 1:minmn)
+ vS = view(S, 1:minmn, 1:minmn)
+ vVᴴ = view(Vᴴ, 1:minmn, :)
+ vdU = view(dU, :, 1:minmn)
+ vdS = view(dS, 1:minmn, 1:minmn)
+ vdVᴴ = view(dVᴴ, 1:minmn, :)
svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
end
zero!(dA)
@@ -362,10 +364,10 @@ for (f!, f) in (
end
function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual)
# compute primal
- A, dA = arrayify(A_dA)
+ A, dA = arrayify(A_dA)
USVᴴ = $f(A, Mooncake.primal(alg_dalg))
# update tangents
- U, S, Vᴴ = USVᴴ
+ U, S, Vᴴ = USVᴴ
dU_ = Mooncake.zero_tangent(U)
dS_ = Mooncake.zero_tangent(S)
dVᴴ_ = Mooncake.zero_tangent(Vᴴ)
@@ -376,12 +378,12 @@ for (f!, f) in (
svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
else # full
minmn = min(size(A)...)
- vU = view(U, :, 1:minmn)
- vS = view(S, 1:minmn, 1:minmn)
- vVᴴ = view(Vᴴ, 1:minmn, :)
- vdU = view(dU, :, 1:minmn)
- vdS = view(dS, 1:minmn, 1:minmn)
- vdVᴴ = view(dVᴴ, 1:minmn, :)
+ vU = view(U, :, 1:minmn)
+ vS = view(S, 1:minmn, 1:minmn)
+ vVᴴ = view(Vᴴ, 1:minmn, :)
+ vdU = view(dU, :, 1:minmn)
+ vdS = view(dS, 1:minmn, 1:minmn)
+ vdVᴴ = view(dVᴴ, 1:minmn, :)
svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
end
return Dual(USVᴴ, (dU_, dS_, dVᴴ_))
@@ -392,8 +394,8 @@ end
@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
- A, dA = arrayify(A_dA)
- alg = Mooncake.primal(alg_dalg)
+ A, dA = arrayify(A_dA)
+ alg = Mooncake.primal(alg_dalg)
output = svd_trunc(A, alg)
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
@@ -404,8 +406,8 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
- U, dU = arrayify(Utrunc, dUtrunc_)
- S, dS = arrayify(Strunc, dStrunc_)
+ U, dU = arrayify(Utrunc, dUtrunc_)
+ S, dS = arrayify(Strunc, dStrunc_)
Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_)
svd_trunc_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
zero!(dU)
@@ -417,23 +419,23 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
end
function Mooncake.frule!!(::Dual{typeof(svd_trunc)}, A_dA::Dual, alg_dalg::Dual)
# compute primal
- A, dA = Mooncake.arrayify(A_dA)
- alg = Mooncake.primal(alg_dalg)
- USVᴴ = svd_compact(A, alg.alg)
+ A, dA = Mooncake.arrayify(A_dA)
+ alg = Mooncake.primal(alg_dalg)
+ USVᴴ = svd_compact(A, alg.alg)
U, S, Vᴴ = USVᴴ
- dUfull = zeros(eltype(U), size(U))
- dSfull = Diagonal(zeros(eltype(S), length(diagview(S))))
+ dUfull = zeros(eltype(U), size(U))
+ dSfull = Diagonal(zeros(eltype(S), length(diagview(S))))
dVᴴfull = zeros(eltype(Vᴴ), size(Vᴴ))
svd_pushforward!(dA, A, (U, S, Vᴴ), (dUfull, dSfull, dVᴴfull))
USVᴴtrunc, ind = truncate(svd_trunc!, USVᴴ, alg.trunc)
- ϵ = truncation_error!(diagview(S), ind)
- output = (USVᴴtrunc..., ϵ)
+ ϵ = truncation_error!(diagview(S), ind)
+ output = (USVᴴtrunc..., ϵ)
output_dual = Mooncake.zero_dual(output)
Utrunc, Strunc, Vᴴtrunc, ϵ = output
dU_, dS_, dVᴴ_, dϵ = Mooncake.tangent(output_dual)
- Utrunc, dU = arrayify(Utrunc, dU_)
- Strunc, dS = arrayify(Strunc, dS_)
+ Utrunc, dU = arrayify(Utrunc, dU_)
+ Strunc, dS = arrayify(Strunc, dS_)
Vᴴtrunc, dVᴴ = arrayify(Vᴴtrunc, dVᴴ_)
dU .= view(dUfull, :, ind)
diagview(dS) .= view(diagview(dSfull), ind)
diff --git a/src/common/view.jl b/src/common/view.jl
index 0bc7b9e..c8ae1aa 100644
--- a/src/common/view.jl
+++ b/src/common/view.jl
@@ -1,5 +1,5 @@
# diagind: provided by LinearAlgebra.jl
-diagview(D::Diagonal) = D.diag
+diagview(D::Diagonal) = D.diag
diagview(D::AbstractMatrix) = view(D, diagind(D))
# triangularind
diff --git a/src/pushforwards/eig.jl b/src/pushforwards/eig.jl
index 6609411..47b4710 100644
--- a/src/pushforwards/eig.jl
+++ b/src/pushforwards/eig.jl
@@ -1,12 +1,12 @@
function eig_pushforward!(ΔA, A, DV, ΔDV; kwargs...)
- D, V = DV
- ΔD, ΔV = ΔDV
- iVΔAV = inv(V) * ΔA * V
+ D, V = DV
+ ΔD, ΔV = ΔDV
+ iVΔAV = inv(V) * ΔA * V
diagview(ΔD) .= diagview(iVΔAV)
if !iszerotangent(ΔV)
- F = 1 ./ (transpose(diagview(D)) .- diagview(D))
+ F = 1 ./ (transpose(diagview(D)) .- diagview(D))
fill!(diagview(F), zero(eltype(F)))
- K̇ = F .* iVΔAV
+ K̇ = F .* iVΔAV
mul!(ΔV, V, K̇, 1, 0)
end
return ΔDV
diff --git a/src/pushforwards/eigh.jl b/src/pushforwards/eigh.jl
index edf418a..d5d663d 100644
--- a/src/pushforwards/eigh.jl
+++ b/src/pushforwards/eigh.jl
@@ -1,16 +1,16 @@
function eigh_pushforward!(dA, A, DV, dDV; kwargs...)
- D, V = DV
- dD, dV = dDV
- tmpV = V \ dA
- ∂K = tmpV * V
- ∂Kdiag = diag(∂K)
+ D, V = DV
+ dD, dV = dDV
+ tmpV = V \ dA
+ ∂K = tmpV * V
+ ∂Kdiag = diag(∂K)
diagview(dD) .= real.(∂Kdiag)
if !iszerotangent(dV)
- dDD = transpose(diagview(D)) .- diagview(D)
- F = one(eltype(dDD)) ./ dDD
+ dDD = transpose(diagview(D)) .- diagview(D)
+ F = one(eltype(dDD)) ./ dDD
diagview(F) .= zero(eltype(F))
- ∂K .*= F
- ∂V = mul!(tmpV, V, ∂K)
+ ∂K .*= F
+ ∂V = mul!(tmpV, V, ∂K)
copyto!(dV, ∂V)
end
return (dD, dV)
diff --git a/src/pushforwards/lq.jl b/src/pushforwards/lq.jl
index 2f5c0e5..6490e1e 100644
--- a/src/pushforwards/lq.jl
+++ b/src/pushforwards/lq.jl
@@ -1,7 +1,7 @@
-function lq_pushforward!(dA, A, LQ, dLQ; tol::Real=default_pullback_gauge_atol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol)
- qr_pushforward!(adjoint(dA), adjoint(A), adjoint.(reverse(LQ)), adjoint.(reverse(dLQ)); tol, rank_atol, gauge_atol)
+function lq_pushforward!(dA, A, LQ, dLQ; tol::Real = default_pullback_gauge_atol(LQ[1]), rank_atol::Real = tol, gauge_atol::Real = tol)
+ return qr_pushforward!(adjoint(dA), adjoint(A), adjoint.(reverse(LQ)), adjoint.(reverse(dLQ)); tol, rank_atol, gauge_atol)
end
-function lq_null_pushforward!(dA, A, Nᴴ, dNᴴ; tol::Real=default_pullback_gauge_atol(Nᴴ), rank_atol::Real=tol, gauge_atol::Real=tol)
- iszero(min(size(Nᴴ)...)) && return # nothing to do
+function lq_null_pushforward!(dA, A, Nᴴ, dNᴴ; tol::Real = default_pullback_gauge_atol(Nᴴ), rank_atol::Real = tol, gauge_atol::Real = tol)
+ return iszero(min(size(Nᴴ)...)) && return # nothing to do
end
diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl
index e8f89bb..1e0da1b 100644
--- a/src/pushforwards/polar.jl
+++ b/src/pushforwards/polar.jl
@@ -1,21 +1,21 @@
function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...)
- W, P = WP
+ W, P = WP
ΔW, ΔP = ΔWP
- aWdA = adjoint(W) * ΔA
- K̇ = sylvester(P, P, -(aWdA - adjoint(aWdA)))
- L̇ = (Diagonal(ones(eltype(W), size(W, 1))) - W*adjoint(W))*ΔA*inv(P)
- ΔW .= W * K̇ + L̇
- ΔP .= aWdA - K̇*P
+ aWdA = adjoint(W) * ΔA
+ K̇ = sylvester(P, P, -(aWdA - adjoint(aWdA)))
+ L̇ = (Diagonal(ones(eltype(W), size(W, 1))) - W * adjoint(W)) * ΔA * inv(P)
+ ΔW .= W * K̇ + L̇
+ ΔP .= aWdA - K̇ * P
return (ΔW, ΔP)
end
function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...)
- P, Wᴴ = PWᴴ
+ P, Wᴴ = PWᴴ
ΔP, ΔWᴴ = ΔPWᴴ
- dAW = ΔA * adjoint(Wᴴ)
- K̇ = sylvester(P, P, -(dAW - adjoint(dAW)))
- L̇ = inv(P)*ΔA*(Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ)
- ΔWᴴ .= K̇ * Wᴴ + L̇
- ΔP .= dAW - P * K̇
+ dAW = ΔA * adjoint(Wᴴ)
+ K̇ = sylvester(P, P, -(dAW - adjoint(dAW)))
+ L̇ = inv(P) * ΔA * (Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ)
+ ΔWᴴ .= K̇ * Wᴴ + L̇
+ ΔP .= dAW - P * K̇
return (ΔWᴴ, ΔP)
end
diff --git a/src/pushforwards/qr.jl b/src/pushforwards/qr.jl
index aba05b0..3778119 100644
--- a/src/pushforwards/qr.jl
+++ b/src/pushforwards/qr.jl
@@ -1,10 +1,10 @@
-function qr_pushforward!(dA, A, QR, dQR; tol::Real=default_pullback_gauge_atol(QR[2]), rank_atol::Real=tol, gauge_atol::Real=tol)
- Q, R = QR
- m = size(A, 1)
- n = size(A, 2)
+function qr_pushforward!(dA, A, QR, dQR; tol::Real = default_pullback_gauge_atol(QR[2]), rank_atol::Real = tol, gauge_atol::Real = tol)
+ Q, R = QR
+ m = size(A, 1)
+ n = size(A, 2)
minmn = min(m, n)
- Rd = diagview(R)
- p = findlast(>=(rank_atol) ∘ abs, Rd)
+ Rd = diagview(R)
+ p = findlast(>=(rank_atol) ∘ abs, Rd)
m1 = p
m2 = minmn - p
@@ -12,50 +12,50 @@ function qr_pushforward!(dA, A, QR, dQR; tol::Real=default_pullback_gauge_atol(Q
n1 = p
n2 = n - p
- Q1 = view(Q, 1:m, 1:m1) # full rank portion
- Q2 = view(Q, 1:m, m1+1:m2+m1)
+ Q1 = view(Q, 1:m, 1:m1) # full rank portion
+ Q2 = view(Q, 1:m, (m1 + 1):(m2 + m1))
R11 = view(R, 1:m1, 1:n1)
- R12 = view(R, 1:m1, n1+1:n)
+ R12 = view(R, 1:m1, (n1 + 1):n)
dA1 = view(dA, 1:m, 1:n1)
dA2 = view(dA, 1:m, (n1 + 1):n)
dQ, dR = dQR
- dQ1 = view(dQ, 1:m, 1:m1)
- dQ2 = view(dQ, 1:m, m1+1:m2+m1)
- dQ3 = minmn+1 < size(dQ, 2) ? view(dQ, :, minmn+1:size(dQ,2)) : similar(dQ, eltype(dQ), (0, 0))
- dR11 = view(dR, 1:m1, 1:n1)
- dR12 = view(dR, 1:m1, n1+1:n)
- dR22 = view(dR, m1+1:m1+m2, n1+1:n)
+ dQ1 = view(dQ, 1:m, 1:m1)
+ dQ2 = view(dQ, 1:m, (m1 + 1):(m2 + m1))
+ dQ3 = minmn + 1 < size(dQ, 2) ? view(dQ, :, (minmn + 1):size(dQ, 2)) : similar(dQ, eltype(dQ), (0, 0))
+ dR11 = view(dR, 1:m1, 1:n1)
+ dR12 = view(dR, 1:m1, (n1 + 1):n)
+ dR22 = view(dR, (m1 + 1):(m1 + m2), (n1 + 1):n)
# fwd rule for Q1 and R11 -- for a non-rank redeficient QR, this is all we need
- invR11 = inv(R11)
- tmp = Q1' * dA1 * invR11
- Rtmp = tmp + tmp'
+ invR11 = inv(R11)
+ tmp = Q1' * dA1 * invR11
+ Rtmp = tmp + tmp'
diagview(Rtmp) ./= 2
- ltRtmp = view(Rtmp, lowertriangularind(Rtmp))
+ ltRtmp = view(Rtmp, lowertriangularind(Rtmp))
ltRtmp .= zero(eltype(Rtmp))
- dR11 .= Rtmp * R11
- dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11
- dR12 .= adjoint(Q1) * (dA2 - dQ1 * R12)
+ dR11 .= Rtmp * R11
+ dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11
+ dR12 .= adjoint(Q1) * (dA2 - dQ1 * R12)
if size(Q2, 2) > 0
- dQ2 .= -Q1 * (Q1' * Q2)
- dQ2 .+= Q2 * (Q2' * dQ2)
+ dQ2 .= -Q1 * (Q1' * Q2)
+ dQ2 .+= Q2 * (Q2' * dQ2)
end
if m3 > 0 && size(Q, 2) > minmn
# only present for qr_full or rank-deficient qr_compact
- Q′ = view(Q, :, 1:minmn)
- Q3 = view(Q, :, minmn+1:m)
+ Q′ = view(Q, :, 1:minmn)
+ Q3 = view(Q, :, (minmn + 1):m)
#dQ3 .= Q′ * (Q′' * Q3)
- dQ3 .= Q3
+ dQ3 .= Q3
end
if !isempty(dR22)
- _, r22 = qr_compact(dA2 - dQ1*R12 - Q1*dR12; positive=true)
- dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2))
+ _, r22 = qr_compact(dA2 - dQ1 * R12 - Q1 * dR12; positive = true)
+ dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2))
end
return (dQ, dR)
end
-function qr_null_pushforward!(dA, A, N, dN; tol::Real=default_pullback_gauge_atol(N), rank_atol::Real=tol, gauge_atol::Real=tol)
- iszero(min(size(N)...)) && return # nothing to do
+function qr_null_pushforward!(dA, A, N, dN; tol::Real = default_pullback_gauge_atol(N), rank_atol::Real = tol, gauge_atol::Real = tol)
+ return iszero(min(size(N)...)) && return # nothing to do
end
diff --git a/src/pushforwards/svd.jl b/src/pushforwards/svd.jl
index d8591d9..f4b547e 100644
--- a/src/pushforwards/svd.jl
+++ b/src/pushforwards/svd.jl
@@ -1,82 +1,82 @@
-function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ; rank_atol=default_pullback_rank_atol(A), kwargs...)
+function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ; rank_atol = default_pullback_rank_atol(A), kwargs...)
U, Smat, Vᴴ = USVᴴ
- m, n = size(U, 1), size(Vᴴ, 2)
+ m, n = size(U, 1), size(Vᴴ, 2)
(m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)"))
minmn = min(m, n)
- S = diagview(Smat)
+ S = diagview(Smat)
ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
r = searchsortedlast(S, rank_atol; rev = true) # rank
- vΔU = view(ΔU, :, 1:r)
- vΔS = view(ΔS, 1:r, 1:r)
+ vΔU = view(ΔU, :, 1:r)
+ vΔS = view(ΔS, 1:r, 1:r)
vΔVᴴ = view(ΔVᴴ, 1:r, :)
- vU = view(U, :, 1:r)
- vS = view(S, 1:r)
+ vU = view(U, :, 1:r)
+ vS = view(S, 1:r)
vSmat = view(Smat, 1:r, 1:r)
- vVᴴ = view(Vᴴ, 1:r, :)
+ vVᴴ = view(Vᴴ, 1:r, :)
# compact region
- vV = adjoint(vVᴴ)
- UΔAV = vU' * ΔA * vV
+ vV = adjoint(vVᴴ)
+ UΔAV = vU' * ΔA * vV
copyto!(diagview(vΔS), diag(real.(UΔAV)))
- F = one(eltype(S)) ./ (transpose(vS) .- vS)
- G = one(eltype(S)) ./ (transpose(vS) .+ vS)
+ F = one(eltype(S)) ./ (transpose(vS) .- vS)
+ G = one(eltype(S)) ./ (transpose(vS) .+ vS)
diagview(F) .= zero(eltype(F))
hUΔAV = F .* (UΔAV + UΔAV') ./ 2
aUΔAV = G .* (UΔAV - UΔAV') ./ 2
- K̇ = hUΔAV + aUΔAV
- Ṁ = hUΔAV - aUΔAV
+ K̇ = hUΔAV + aUΔAV
+ Ṁ = hUΔAV - aUΔAV
# check gauge condition
@assert isantihermitian(K̇)
@assert isantihermitian(Ṁ)
K̇diag = diagview(K̇)
for i in 1:length(K̇diag)
- @assert K̇diag[i] ≈ (im/2) * imag(diagview(UΔAV)[i])/S[i]
+ @assert K̇diag[i] ≈ (im / 2) * imag(diagview(UΔAV)[i]) / S[i]
end
- ∂U = vU * K̇
- ∂V = vV * Ṁ
+ ∂U = vU * K̇
+ ∂V = vV * Ṁ
# full component
if size(U, 2) > minmn && size(Vᴴ, 1) > minmn
- Uperp = view(U, :, minmn+1:m)
- Vᴴperp = view(Vᴴ, minmn+1:n, :)
+ Uperp = view(U, :, (minmn + 1):m)
+ Vᴴperp = view(Vᴴ, (minmn + 1):n, :)
- aUAV = adjoint(Uperp) * A * adjoint(Vᴴperp)
+ aUAV = adjoint(Uperp) * A * adjoint(Vᴴperp)
- UÃÃV = similar(A, (size(aUAV, 1) + size(aUAV, 2), size(aUAV, 1) + size(aUAV, 2)))
+ UÃÃV = similar(A, (size(aUAV, 1) + size(aUAV, 2), size(aUAV, 1) + size(aUAV, 2)))
fill!(UÃÃV, 0)
view(UÃÃV, (1:size(aUAV, 1)), size(aUAV, 1) .+ (1:size(aUAV, 2))) .= aUAV
- view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV'
- rhs = vcat( adjoint(Uperp, ΔA, V), Vᴴperp * ΔA' * U)
+ view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV'
+ rhs = vcat(adjoint(Uperp, ΔA, V), Vᴴperp * ΔA' * U)
superKM = -sylvester(UÃÃV, Smat, rhs)
K̇perp = view(superKM, 1:size(aUAV, 2))
- Ṁperp = view(superKM, size(aUAV, 2)+1:size(aUAV, 1)+size(aUAV, 2))
- ∂U .+= Uperp * K̇perp
- ∂V .+= Vperp * Ṁperp
+ Ṁperp = view(superKM, (size(aUAV, 2) + 1):(size(aUAV, 1) + size(aUAV, 2)))
+ ∂U .+= Uperp * K̇perp
+ ∂V .+= Vperp * Ṁperp
else
- ImUU = (LinearAlgebra.diagm(ones(eltype(U), m)) - vU*vU')
- ImVV = (LinearAlgebra.diagm(ones(eltype(Vᴴ), n)) - vV*vVᴴ)
- upper = ImUU * ΔA * vV
+ ImUU = (LinearAlgebra.diagm(ones(eltype(U), m)) - vU * vU')
+ ImVV = (LinearAlgebra.diagm(ones(eltype(Vᴴ), n)) - vV * vVᴴ)
+ upper = ImUU * ΔA * vV
lower = ImVV * ΔA' * vU
- rhs = vcat(upper, lower)
-
- Ã = ImUU * A * ImVV
- ÃÃ = similar(A, (m + n, m + n))
+ rhs = vcat(upper, lower)
+
+ Ã = ImUU * A * ImVV
+ ÃÃ = similar(A, (m + n, m + n))
fill!(ÃÃ, 0)
view(ÃÃ, (1:m), m .+ (1:n)) .= Ã
- view(ÃÃ, m .+ (1:n), 1:m ) .= Ã'
+ view(ÃÃ, m .+ (1:n), 1:m) .= Ã'
superLN = -sylvester(ÃÃ, vSmat, rhs)
- ∂U += view(superLN, 1:size(upper, 1), :)
- ∂V += view(superLN, size(upper, 1)+1:size(upper,1)+size(lower,1), :)
+ ∂U += view(superLN, 1:size(upper, 1), :)
+ ∂V += view(superLN, (size(upper, 1) + 1):(size(upper, 1) + size(lower, 1)), :)
end
- copyto!(vΔU, ∂U)
+ copyto!(vΔU, ∂U)
adjoint!(vΔVᴴ, ∂V)
return (ΔU, ΔS, ΔVᴴ)
end
-function svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol=default_pullback_rank_atol(A), kwargs...)
+function svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol = default_pullback_rank_atol(A), kwargs...)
end
diff --git a/test/mooncake.jl b/test/mooncake.jl
index 23e5b07..9aae281 100644
--- a/test/mooncake.jl
+++ b/test/mooncake.jl
@@ -22,7 +22,7 @@ make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), T...)
make_mooncake_fdata(x) = make_mooncake_tangent(x)
make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),))
-ETs = (Float64, ComplexF64,)# Float32,)# ComplexF64, ComplexF32)
+ETs = (Float64, ComplexF64) # Float32,)# ComplexF64, ComplexF32)
# no `alg` argument
function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata)
@@ -119,8 +119,8 @@ end
rng = StableRNG(12345)
m = 19
@testset "size ($m, $n)" for n in (17, m, 23)
- atol = rtol = m * n * precision(T)
- A = randn(rng, T, m, n)
+ atol = rtol = m * n * precision(T)
+ A = randn(rng, T, m, n)
minmn = min(m, n)
@testset for alg in (
LAPACK_HouseholderQR(),
@@ -128,9 +128,9 @@ end
)
@testset "qr_compact" begin
QR = qr_compact(A, alg)
- Q = randn(rng, T, m, minmn)
- R = randn(rng, T, minmn, n)
- Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; atol=atol, rtol=rtol)
+ Q = randn(rng, T, m, minmn)
+ R = randn(rng, T, minmn, n)
+ Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; atol = atol, rtol = rtol)
test_pullbacks_match(rng, qr_compact!, qr_compact, A, (Q, R), (randn(rng, T, m, minmn), randn(rng, T, minmn, n)), alg)
end
@testset "qr_null" begin
@@ -138,46 +138,46 @@ end
ΔN = Q * randn(rng, T, minmn, max(0, m - minmn))
N = qr_null(A, alg)
dN = make_mooncake_tangent(copy(ΔN))
- Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; output_tangent = dN, atol=atol, rtol=rtol)
+ Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; output_tangent = dN, atol = atol, rtol = rtol)
test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN, alg)
end
@testset "qr_full" begin
Q, R = qr_full(A, alg)
- Q1 = view(Q, 1:m, 1:minmn)
- ΔQ = randn(rng, T, m, m)
- ΔQ2 = view(ΔQ, :, (minmn + 1):m)
+ Q1 = view(Q, 1:m, 1:minmn)
+ ΔQ = randn(rng, T, m, m)
+ ΔQ2 = view(ΔQ, :, (minmn + 1):m)
mul!(ΔQ2, Q1, Q1' * ΔQ2)
- ΔR = randn(rng, T, m, n)
- dQ = make_mooncake_tangent(copy(ΔQ))
- dR = make_mooncake_tangent(copy(ΔR))
- dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR)
+ ΔR = randn(rng, T, m, n)
+ dQ = make_mooncake_tangent(copy(ΔQ))
+ dR = make_mooncake_tangent(copy(ΔR))
+ dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR)
#Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; output_tangent = dQR, atol=atol, rtol=rtol)
- Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[2]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
- Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[1][1:m, 1:minmn]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
- Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[1][1:m, minmn+1:m]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
+ Mooncake.TestUtils.test_rule(rng, ((A, alg) -> qr_full(A, alg)[2]), A, alg; mode = Mooncake.ForwardMode, is_primitive = false, atol = atol, rtol = rtol)
+ Mooncake.TestUtils.test_rule(rng, ((A, alg) -> qr_full(A, alg)[1][1:m, 1:minmn]), A, alg; mode = Mooncake.ForwardMode, is_primitive = false, atol = atol, rtol = rtol)
+ Mooncake.TestUtils.test_rule(rng, ((A, alg) -> qr_full(A, alg)[1][1:m, (minmn + 1):m]), A, alg; mode = Mooncake.ForwardMode, is_primitive = false, atol = atol, rtol = rtol)
test_pullbacks_match(rng, qr_full!, qr_full, A, (Q, R), (ΔQ, ΔR), alg)
end
@testset "qr_compact - rank-deficient A" begin
- r = minmn - 5
- Ard = randn(rng, T, m, r) * randn(rng, T, r, n)
+ r = minmn - 5
+ Ard = randn(rng, T, m, r) * randn(rng, T, r, n)
Q, R = qr_compact(Ard, alg)
- QR = (Q, R)
- ΔQ = randn(rng, T, m, minmn)
- Q1 = view(Q, 1:m, 1:r)
- Q2 = view(Q, 1:m, (r + 1):minmn)
- ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn)
+ QR = (Q, R)
+ ΔQ = randn(rng, T, m, minmn)
+ Q1 = view(Q, 1:m, 1:r)
+ Q2 = view(Q, 1:m, (r + 1):minmn)
+ ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn)
ΔQ2 .= 0
- ΔR = randn(rng, T, minmn, n)
+ ΔR = randn(rng, T, minmn, n)
view(ΔR, (r + 1):minmn, :) .= 0
- dQ = make_mooncake_tangent(copy(ΔQ))
- dR = make_mooncake_tangent(copy(ΔR))
- dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR)
- Mooncake.TestUtils.test_rule(rng, qr_compact, copy(Ard), alg; output_tangent = dQR, atol=atol, rtol=rtol)
- Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_compact(A, alg)[2]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
- Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_compact(A, alg)[1][1:r, 1:r]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
- Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_compact(A, alg)[1][r+1:m, 1:r]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
- Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_compact(A, alg)[1][1:r, r+1:minmn]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
- Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_compact(A, alg)[1][r+1:m, r+1:minmn]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol)
+ dQ = make_mooncake_tangent(copy(ΔQ))
+ dR = make_mooncake_tangent(copy(ΔR))
+ dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR)
+ Mooncake.TestUtils.test_rule(rng, qr_compact, copy(Ard), alg; output_tangent = dQR, atol = atol, rtol = rtol)
+ Mooncake.TestUtils.test_rule(rng, ((A, alg) -> qr_compact(A, alg)[2]), A, alg; mode = Mooncake.ForwardMode, is_primitive = false, atol = atol, rtol = rtol)
+ Mooncake.TestUtils.test_rule(rng, ((A, alg) -> qr_compact(A, alg)[1][1:r, 1:r]), A, alg; mode = Mooncake.ForwardMode, is_primitive = false, atol = atol, rtol = rtol)
+ Mooncake.TestUtils.test_rule(rng, ((A, alg) -> qr_compact(A, alg)[1][(r + 1):m, 1:r]), A, alg; mode = Mooncake.ForwardMode, is_primitive = false, atol = atol, rtol = rtol)
+ Mooncake.TestUtils.test_rule(rng, ((A, alg) -> qr_compact(A, alg)[1][1:r, (r + 1):minmn]), A, alg; mode = Mooncake.ForwardMode, is_primitive = false, atol = atol, rtol = rtol)
+ Mooncake.TestUtils.test_rule(rng, ((A, alg) -> qr_compact(A, alg)[1][(r + 1):m, (r + 1):minmn]), A, alg; mode = Mooncake.ForwardMode, is_primitive = false, atol = atol, rtol = rtol)
test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, (Q, R), (ΔQ, ΔR), alg)
end
end
@@ -202,39 +202,39 @@ end
end
@testset "lq_null" begin
L, Q = lq_compact(A, alg)
- ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q
- Nᴴ = randn(rng, T, max(0, n - minmn), n)
- dNᴴ = make_mooncake_tangent(ΔNᴴ)
+ ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q
+ Nᴴ = randn(rng, T, max(0, n - minmn), n)
+ dNᴴ = make_mooncake_tangent(ΔNᴴ)
Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; output_tangent = dNᴴ, atol = atol, rtol = rtol)
test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg)
end
@testset "lq_full" begin
L, Q = lq_full(A, alg)
- Q1 = view(Q, 1:minmn, 1:n)
- ΔQ = randn(rng, T, n, n)
- ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n)
+ Q1 = view(Q, 1:minmn, 1:n)
+ ΔQ = randn(rng, T, n, n)
+ ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n)
mul!(ΔQ2, ΔQ2 * Q1', Q1)
- ΔL = randn(rng, T, m, n)
- dL = make_mooncake_tangent(ΔL)
- dQ = make_mooncake_tangent(ΔQ)
- dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ)
+ ΔL = randn(rng, T, m, n)
+ dL = make_mooncake_tangent(ΔL)
+ dQ = make_mooncake_tangent(ΔQ)
+ dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ)
Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; output_tangent = dLQ, atol = atol, rtol = rtol)
test_pullbacks_match(rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ), alg)
end
@testset "lq_compact - rank-deficient A" begin
- r = minmn - 5
- Ard = randn(rng, T, m, r) * randn(rng, T, r, n)
+ r = minmn - 5
+ Ard = randn(rng, T, m, r) * randn(rng, T, r, n)
L, Q = lq_compact(Ard, alg)
- ΔL = randn(rng, T, m, minmn)
- ΔQ = randn(rng, T, minmn, n)
- Q1 = view(Q, 1:r, 1:n)
- Q2 = view(Q, (r + 1):minmn, 1:n)
- ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n)
+ ΔL = randn(rng, T, m, minmn)
+ ΔQ = randn(rng, T, minmn, n)
+ Q1 = view(Q, 1:r, 1:n)
+ Q2 = view(Q, (r + 1):minmn, 1:n)
+ ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n)
ΔQ2 .= 0
view(ΔL, :, (r + 1):minmn) .= 0
- dL = make_mooncake_tangent(ΔL)
- dQ = make_mooncake_tangent(ΔQ)
- dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ)
+ dL = make_mooncake_tangent(ΔL)
+ dQ = make_mooncake_tangent(ΔQ)
+ dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ)
Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; output_tangent = dLQ, atol = atol, rtol = rtol)
test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ), alg)
end
@@ -269,7 +269,7 @@ end
test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg)
end
@testset "eig_trunc" begin
- Ah = (A + A')/2
+ Ah = (A + A') / 2
for r in 1:4:m
truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs))
ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc)
@@ -448,42 +448,42 @@ end
ΔS = randn(rng, real(T), minmn, minmn)
ΔS2 = Diagonal(randn(rng, real(T), minmn))
ΔVᴴ = randn(rng, T, minmn, n)
- ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
+ ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
truncalg = TruncatedAlgorithm(alg, truncrank(r))
- ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
- Strunc = Diagonal(diagview(S)[ind])
- Utrunc = U[:, ind]
- Vᴴtrunc = Vᴴ[ind, :]
- ΔStrunc = Diagonal(diagview(ΔS2)[ind])
- ΔUtrunc = ΔU[:, ind]
+ ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
+ Strunc = Diagonal(diagview(S)[ind])
+ Utrunc = U[:, ind]
+ Vᴴtrunc = Vᴴ[ind, :]
+ ΔStrunc = Diagonal(diagview(ΔS2)[ind])
+ ΔUtrunc = ΔU[:, ind]
ΔVᴴtrunc = ΔVᴴ[ind, :]
- dStrunc = make_mooncake_tangent(ΔStrunc)
- dUtrunc = make_mooncake_tangent(ΔUtrunc)
+ dStrunc = make_mooncake_tangent(ΔStrunc)
+ dUtrunc = make_mooncake_tangent(ΔUtrunc)
dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc)
- ϵ = zero(real(T))
+ ϵ = zero(real(T))
dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ)
Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; output_tangent = dUSVᴴerr, atol = atol, rtol = rtol)
test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
end
@testset "trunctol" begin
U, S, Vᴴ = svd_compact(A)
- ΔU = randn(rng, T, m, minmn)
- ΔS = randn(rng, real(T), minmn, minmn)
- ΔS2 = Diagonal(randn(rng, real(T), minmn))
- ΔVᴴ = randn(rng, T, minmn, n)
- ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
+ ΔU = randn(rng, T, m, minmn)
+ ΔS = randn(rng, real(T), minmn, minmn)
+ ΔS2 = Diagonal(randn(rng, real(T), minmn))
+ ΔVᴴ = randn(rng, T, minmn, n)
+ ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2))
- ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
- Strunc = Diagonal(diagview(S)[ind])
- Utrunc = U[:, ind]
- Vᴴtrunc = Vᴴ[ind, :]
- ΔStrunc = Diagonal(diagview(ΔS2)[ind])
- ΔUtrunc = ΔU[:, ind]
+ ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
+ Strunc = Diagonal(diagview(S)[ind])
+ Utrunc = U[:, ind]
+ Vᴴtrunc = Vᴴ[ind, :]
+ ΔStrunc = Diagonal(diagview(ΔS2)[ind])
+ ΔUtrunc = ΔU[:, ind]
ΔVᴴtrunc = ΔVᴴ[ind, :]
- dStrunc = make_mooncake_tangent(ΔStrunc)
- dUtrunc = make_mooncake_tangent(ΔUtrunc)
+ dStrunc = make_mooncake_tangent(ΔStrunc)
+ dUtrunc = make_mooncake_tangent(ΔUtrunc)
dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc)
- ϵ = zero(real(T))
+ ϵ = zero(real(T))
dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ)
Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; output_tangent = dUSVᴴerr, atol = atol, rtol = rtol)
test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
@@ -520,12 +520,12 @@ right_orth_lq(X) = right_orth(X; alg = :lq)
right_orth_polar(X) = right_orth(X; alg = :polar)
right_null_lq(X) = right_null(X; alg = :lq)
-MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A)
-MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A)
-MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A)
-MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A)
+MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A)
+MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A)
+MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A)
+MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A)
MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A)
-MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A)
+MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A)
@timedtestset "Orth and null with eltype $T" for T in ETs
rng = StableRNG(12345) |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
No description provided.