Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set AD rules #93

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft

Conversation

elisno
Copy link
Contributor

@elisno elisno commented Apr 3, 2021

Resolves #92.

@codecov-io
Copy link

codecov-io commented Apr 3, 2021

Codecov Report

Merging #93 (acc9a90) into master (4967a5f) will decrease coverage by 1.25%.
The diff coverage is 69.23%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master      #93      +/-   ##
==========================================
- Coverage   90.96%   89.71%   -1.26%     
==========================================
  Files          11       11              
  Lines         620      632      +12     
==========================================
+ Hits          564      567       +3     
- Misses         56       65       +9     
Impacted Files Coverage Δ
src/chainrules.jl 69.23% <69.23%> (ø)
src/vectrick.jl 90.55% <0.00%> (-3.94%) ⬇️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 4967a5f...acc9a90. Read the comment docs.

@MichielStock
Copy link
Owner

MichielStock commented Apr 6, 2021

Your code does not seem to work for some examples and gives the wrong result for others.

For example:

gradient((A, B)->sum(AB), A, B)
gradient((A, B)->sum(kron(A,B)), A, B)

Most of our Kronecker functions fall back on regular function Zygote etc should be able to handle fine. It works for logdet but not for tr and sum (which work with the native kron, e.g. gradient((A, B)-> sum(kron(A, B)), A, B). Not sure why or how we can make ChainRulesCore fall back on this underlying code.

Do you have a reference for your gradients?

Adds a testing-function for different 'output' dimensions of each factor in the Kronecker product. It defines linear regression models with the sum of squared residuals as a loss function. Currently only works for residuals of scalar outputs. Tests are broken for outputs of higher dimensions.
@elisno
Copy link
Contributor Author

elisno commented Apr 6, 2021

Your code does not seem to work for some examples and gives the wrong result for others.

You're right, I started with the following loss function:

function loss(A, B, X)
    Z = kron(A, B)*X - y
    L = 0.5 * tr(Z' * Z)
    return L
end

where y has size (1, num_samples).
I wrote kronecker_product_pullback in the rrule with this in mind, but forgot that each sample in y can have a higher dimension.

In test/testchainrules.jl, I make a comparison of Zygote.gradient with hand-written gradients for this trivial case. I do another comparison with kronecker.

I decided to leave similar tests for higher-dimensions, but leave them with @test_broken for now.

@elisno
Copy link
Contributor Author

elisno commented Apr 6, 2021

I've been experimenting with KroneckerSum as well.

I managed to get the correct values for the pullback:

function ChainRulesCore.frule((_, ΔA, ΔB), ::KroneckerSum, A::AbstractMatrix, B::AbstractMatrix)
    Ω = (A  B)
    ∂Ω = (ΔA  ΔB)
    return Ω, ∂Ω
end

function ChainRulesCore.rrule(::typeof(KroneckerSum), A::AbstractMatrix, B::AbstractMatrix)
    function kronecker_sum_pullback(ΔΩ)
        ∂A = nB .* A + Diagonal(fill(tr(B), nA))
        ∂B = nA .* B + Diagonal(fill(tr(A), nB))
        return (NO_FIELDS, ∂A, ∂B)
    end
    return (A  B), kronecker_sum_pullback
end

nA = 3
nB = 2
Ar = rand(nA,nA)
Br = rand(nB,nB)
Y_lazy, back_lazy = Zygote._pullback(, Ar, Br)
Y, back = Zygote._pullback((x,y) -> kron(x, Diagonal(ones(nB))) + kron(Diagonal(ones(nA)), y), Ar, Br)
julia> back(Y)[2:end] .≈ back_lazy(Y_lazy)[2:end]
(true, true)

Of course, this isn't useful for computing the gradient in more complicated expressions, since ΔΩ is not used in computing either ∂A or ∂B in the rrule.

@elisno
Copy link
Contributor Author

elisno commented Apr 6, 2021

Note that:

ChainRulesCore.rrule(::typeof(KroneckerSum), A::AbstractMatrix, B::AbstractMatrix)

overwrites

ChainRulesCore.rrule(::typeof(KroneckerProduct), A::AbstractMatrix, B::AbstractMatrix)

Should I use something else instead of ::typeof(KroneckerProduct)/::typeof(KroneckerSum)?

@MichielStock
Copy link
Owner

Still stuck on this, why does computing gradients work for logdet but not tr or sum. It should just fall back to the simple shortcuts, for which adjoints already exist?

@MichielStock
Copy link
Owner

Technically, it only makes sense to define the adjoints for those function where Kronecker provides shortcuts, based on this rule: https://en.wikipedia.org/wiki/Matrix_calculus#Identities_in_differential_form

@elisno
Copy link
Contributor Author

elisno commented Apr 12, 2021

Still stuck on this, why does computing gradients work for logdet

Can you provide a MWE for logdet?

@elisno
Copy link
Contributor Author

elisno commented Apr 12, 2021

Technically, it only makes sense to define the adjoints for those function where Kronecker provides shortcuts, based on this rule: https://en.wikipedia.org/wiki/Matrix_calculus#Identities_in_differential_form

Maybe I misunderstood, but doesn't this only provide the frules?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

AD rules that apply to KroneckerProducts
3 participants