AD rules that apply to KroneckerProducts #92

Open
elisno opened this issue Mar 31, 2021 · 2 comments · May be fixed by #93

@elisno (Contributor) commented Mar 31, 2021

(Related to #11)

I'm trying to wrap my head around getting gradients with kron/kronecker.

  1. Is it sufficient to define custom AD rules for the vec-trick with ChainRulesCore.jl? Something like the following (see also the fuller sketch after this list):
function rrule(::typeof(*), K::KroneckerProduct, x::AbstractVector)
    function times_vec_pullback(ΔΩ)
        ...
    end
    return K*x, times_vec_pullback
end

function rrule(::typeof(*), K::KroneckerProduct, X::AbstractMatrix)
    function times_mat_pullback(ΔΩ)
        ...
    end
    return K*X, times_mat_pullback
end
  2. Do we also need to define a rule for the kronecker constructor to get gradients?
function rrule(::typeof(kronecker), A::AbstractMatrix, B::AbstractMatrix)
    function kronecker_pullback(ΔΩ)
        ...
    end
    return kronecker(A, B), kronecker_pullback
end
  3. Should the pullbacks also be lazy? I found this to be a decent overview of finding vectorized derivatives. Would the pullbacks then just be reshape rules for those vectorized derivatives?
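For question 1, here is a rough sketch of what I have in mind, treating K as fixed and leaving its cotangent unimplemented for now. The names, the use of @not_implemented, and the assumption that K' stays a lazy Kronecker product are placeholders on my side, not anything that exists in the package yet:

using ChainRulesCore
using Kronecker

# Sketch: pullback for Ω = K * x with K = A ⊗ B treated as fixed.
# K' should again be a lazy Kronecker product (A' ⊗ B'), so the cotangent
# w.r.t. x can itself be computed with the vec-trick.
function ChainRulesCore.rrule(::typeof(*), K::KroneckerProduct, x::AbstractVector)
    Ω = K * x
    function times_vec_pullback(ΔΩ)
        Δx = K' * unthunk(ΔΩ)   # (A' ⊗ B') * ΔΩ
        ΔK = @not_implemented("cotangent w.r.t. K; better handled by a constructor rule")
        return NoTangent(), ΔK, Δx
    end
    return Ω, times_vec_pullback
end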
@MichielStock (Owner) commented

The question might be what is fixed and what you want to compute the derivative of. I originally conceived Kronecker to work with systems of the form f(K * w), where you might want to optimize w as a parameter matrix. This should be easy enough.
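Something like the following, where K stays lazy and only w carries gradients (toy sizes, and the least-squares objective is just for illustration):

using Kronecker

A, B = rand(4, 3), rand(5, 6)
K = A ⊗ B                        # lazy 20×18 KroneckerProduct, never materialized
w = rand(size(K, 2))
target = rand(size(K, 1))

# Objective in w, with K held fixed:
loss(w) = sum(abs2, K * w - target) / 2

# Its gradient w.r.t. w is just K' applied to the residual,
# which can again use the vec-trick:
grad_w = K' * (K * w - target)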

Taking the gradients of the Kronecker matrix itself would be something like A ⊗ B => I ⊗ B and A ⊗ I.
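Roughly, I would expect a constructor pullback along these lines (untested sketch, real matrices assumed; the block-indexing helper and names are just for illustration):

using ChainRulesCore
using Kronecker
using LinearAlgebra

# Untested sketch of a pullback for kronecker(A, B), real matrices assumed.
# With A of size m×p and B of size n×q, ΔΩ has size (m*n, p*q); view it as an
# m×p grid of n×q blocks, where block (i, j) corresponds to A[i, j] * B.
function ChainRulesCore.rrule(::typeof(kronecker), A::AbstractMatrix, B::AbstractMatrix)
    K = kronecker(A, B)
    function kronecker_pullback(ΔΩ)
        Δ = unthunk(ΔΩ)
        m, p = size(A)
        n, q = size(B)
        block(i, j) = view(Δ, (i - 1) * n + 1:i * n, (j - 1) * q + 1:j * q)
        ΔA = [dot(B, block(i, j)) for i in 1:m, j in 1:p]       # sum(B .* block)
        ΔB = sum(A[i, j] * block(i, j) for i in 1:m, j in 1:p)  # weighted sum of blocks
        return NoTangent(), ΔA, ΔB
    end
    return K, kronecker_pullback
end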

Maybe dot(x, A, y) might also be a special case?

I have been working with ChainRulesCore, so you might open a PR and we can look at it together?

@elisno (Contributor, Author) commented Apr 2, 2021

> Taking the gradients of the Kronecker matrix itself would be something like A ⊗ B => I ⊗ B and A ⊗ I.

It doesn't appear to be quite that straightforward. Care must be taken to set the appropriate size of I for each partial derivative, and a (conjugate?) transpose needs to be taken of some of the matrices.
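Concretely, using the vec-trick identity (A ⊗ B) * vec(X) = vec(B * X * A') and assuming real matrices: if the upstream cotangent is ΔΩ = vec(ΔW), with ΔW of size (rows(B), rows(A)) and X of size (cols(B), cols(A)), then the factor cotangents should come out as ΔA = ΔW' * B * X and ΔB = ΔW * A * X' (conjugate transposes in the complex case). I haven't verified this carefully yet, so treat it as a sketch.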

> I have been working with ChainRulesCore, so you might open a PR and we can look at it together?

I've managed to put together a semi-working example with the eager kron and Zygote.gradient. I'd have to review how I do the first steps with the chain rule. I'll open a PR today.

using LinearAlgebra
using Random
using Zygote

M, N = 3, 2
n_samples = 3

Random.seed!(0)
A = rand(1, N)
B = rand(1, M)
x = rand(M * N, n_samples)
y = rand(n_samples)

model(A, B, X) = kron(A, B) * X

# Scalar squared-error loss (y is captured from the enclosing scope).
function loss(A, B, X)
    Z = model(A, B, X) - y'
    L = 0.5 * Z * Z'
    return L[1]
end

# Hand-written partial derivative of the loss w.r.t. A.
function gradient_A(A, B, x)
    Z = model(A, B, x) - y'
    n = size(A, 2)
    IA_col = Diagonal(ones(n))
    return Z * (kron(IA_col', B) * x)'
end

# Hand-written partial derivative of the loss w.r.t. B.
function gradient_B(A, B, x)
    Z = model(A, B, x) - y'
    n = size(B, 2)
    IB_col = Diagonal(ones(n))
    return Z * (kron(A, IB_col) * x)'
end

# Compare the hand-written gradients with Zygote.gradient on the loss function.
@assert gradient_A(A, B, x) ≈ gradient(loss, A, B, x)[1]
@assert gradient_B(A, B, x) ≈ gradient(loss, A, B, x)[2]

# Show partial derivatives of the loss function w.r.t. the Kronecker factors.
@show gradient(loss, A, B, x)[1:2]

> Maybe dot(x, A, y) might also be a special case?

What did you have in mind for this?

elisno linked a pull request on Apr 3, 2021 that will close this issue.