diff --git a/Project.toml b/Project.toml index befe7e2..2d94d10 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["MichielStock "] version = "0.4.3" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f" @@ -19,8 +20,9 @@ julia = "1" [extras] PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] benchmark = ["PkgBenchmark"] -test = ["Random", "Test"] +test = ["Random", "Test", "Zygote"] diff --git a/src/Kronecker.jl b/src/Kronecker.jl index 453f29b..1b465ee 100644 --- a/src/Kronecker.jl +++ b/src/Kronecker.jl @@ -14,6 +14,7 @@ import LinearAlgebra: mul!, lmul!, rmul!, pinv, ldiv! import Base: collect, *, getindex, size, eltype, inv, adjoint using SparseArrays using LinearAlgebra: checksquare +using ChainRulesCore include("base.jl") include("kroneckerpowers.jl") @@ -25,5 +26,6 @@ include("eigen.jl") include("factorization.jl") include("kroneckergraphs.jl") include("names.jl") +include("chainrules.jl") end # module diff --git a/src/chainrules.jl b/src/chainrules.jl new file mode 100644 index 0000000..3407932 --- /dev/null +++ b/src/chainrules.jl @@ -0,0 +1,19 @@ +function ChainRulesCore.frule((_, ΔA, ΔB), ::typeof(KroneckerProduct), A::AbstractMatrix, B::AbstractMatrix) + Ω = (A ⊗ B) + ∂Ω = (ΔA ⊗ B) + (A ⊗ ΔB) + return Ω, ∂Ω +end + +function ChainRulesCore.rrule(::typeof(KroneckerProduct), A::AbstractMatrix, B::AbstractMatrix) + function kronecker_product_pullback(ΔΩ) + nA = size(A, 2) + IA_col = Diagonal(ones(nA)) + ∂A = ΔΩ * (IA_col ⊗ B') + + nB = size(B, 2) + IB_col = Diagonal(ones(nB)) + ∂B = ΔΩ * (A' ⊗ IB_col) + return (NO_FIELDS, ∂A, ∂B) + end + return (A ⊗ B), kronecker_product_pullback +end \ No newline at end of file diff 
--git a/test/runtests.jl b/test/runtests.jl index ef1d397..c283888 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Kronecker, Test, LinearAlgebra, Random, FillArrays using SparseArrays: AbstractSparseMatrix, SparseMatrixCSC, sprand, sparse, issparse +using Zygote: gradient @testset "Kronecker" begin include("testbase.jl") @@ -12,4 +13,5 @@ using SparseArrays: AbstractSparseMatrix, SparseMatrixCSC, sprand, include("testkroneckersum.jl") include("testfactorization.jl") include("testkroneckergraphs.jl") + include("testchainrules.jl") end diff --git a/test/testchainrules.jl b/test/testchainrules.jl new file mode 100644 index 0000000..573bc1c --- /dev/null +++ b/test/testchainrules.jl @@ -0,0 +1,70 @@ +function test_gradients_for_KroneckerProduct(M_out, N_out) + M, N = 3, 2 + n_samples = 3 + A = rand(N_out, N) + B = rand(M_out, M) + x = rand(M*N, n_samples) + y = rand(M_out*N_out, n_samples) + + eager_model(A, B, X) = kron(A, B) * X + + function loss(A, B, X) + Z = eager_model(A, B, X) - y + L = 0.5 * tr(Z' * Z) + return L + end + + lazy_model(A, B, X) = (A ⊗ B) * X + + function lazy_loss(A, B, X) + Z = lazy_model(A, B, X) - y + L = 0.5 * tr(Z' * Z) + return L + end + + function gradient_A(A, B, x) + Z = eager_model(A, B, x) - y + m, n = size(A) + IA = Diagonal(ones(m*n)) + return Z * (kron(IA', B) * x)' + end + + function gradient_B(A, B, x) + Z = eager_model(A, B, x) - y + m, n = size(B) + IB = Diagonal(ones(m*n)) + return Z * (kron(A, IB) * x)' + end + + function gradient_x(A, B, x) + Z = eager_model(A, B, x) - y + return kron(A, B)'*Z + end + + if (M_out, N_out) == (1,1) + @testset "Gradients for M_out=$M_out, N_out=$N_out" begin + gA, gB, gx = gradient(loss, A, B, x) + # Compare hand-written gradients with running Zygote.gradient on the loss function + @test gradient_A(A, B, x) ≈ gA + @test gradient_B(A, B, x) ≈ gB + @test gradient_x(A, B, x) ≈ gx + # Compare `Base.kron` with `Kronecker.kronecker` in Zygote + @test all(gradient(loss, A, 
B, x) .≈ gradient(lazy_loss, A, B, x)) + end + else + @testset "Gradients for M_out=$M_out, N_out=$N_out" begin + gA, gB, gx = gradient(loss, A, B, x) + # Compare hand-written gradients with running Zygote.gradient on the loss function + @test_broken gradient_A(A, B, x) ≈ gA + @test_broken gradient_B(A, B, x) ≈ gB + @test gradient_x(A, B, x) ≈ gx + # Compare `Base.kron` with `Kronecker.kronecker` in Zygote + @test_broken all(gradient(loss, A, B, x) .≈ gradient(lazy_loss, A, B, x)) + end + end +end + +# Exercise gradients over output factor sizes (M_out, N_out); the product (A ⊗ B) * x has size (M_out*N_out, n_samples=3). +for (Mo, No) in ((1,1), (2, 3)) + test_gradients_for_KroneckerProduct(Mo, No) +end