norm(NoTangent()) causes StackOverflow #639

Closed
lukas-weber opened this issue Dec 4, 2023 · 8 comments · Fixed by #642
Labels
bug (Something isn't working), good first issue (Good for newcomers)

Comments

@lukas-weber
Contributor

using ChainRulesCore
using LinearAlgebra  # provides norm
norm(NoTangent())

leads to

ERROR: StackOverflowError:
Stacktrace:
     [1] norm(itr::NoTangent)
       @ LinearAlgebra ~/packages/julias/julia-1.9/share/julia/stdlib/v1.9/LinearAlgebra/src/generic.jl:596
     [2] (::Base.MappingRF{typeof(norm), Base.BottomRF{typeof(max)}})(acc::Base._InitialValue, x::NoTangent)
       @ Base ./reduce.jl:95
     [3] _foldl_impl(op::Base.MappingRF{typeof(norm), Base.BottomRF{typeof(max)}}, init::Base._InitialValue, itr::NoTangent)
       @ Base ./reduce.jl:58
     [4] foldl_impl(op::Base.MappingRF{typeof(norm), Base.BottomRF{typeof(max)}}, nt::Base._InitialValue, itr::NoTangent)
       @ Base ./reduce.jl:48
     [5] mapfoldl_impl(f::typeof(norm), op::typeof(max), nt::Base._InitialValue, itr::NoTangent)
       @ Base ./reduce.jl:44
     [6] mapfoldl(f::Function, op::Function, itr::NoTangent; init::Base._InitialValue)
       @ Base ./reduce.jl:170
     [7] mapfoldl
       @ ./reduce.jl:170 [inlined]
     [8] #mapreduce#292
       @ ./reduce.jl:302 [inlined]
     [9] mapreduce(f::Function, op::Function, itr::NoTangent)
       @ Base ./reduce.jl:302
    [10] generic_normInf(x::NoTangent)
       @ LinearAlgebra ~/packages/julias/julia-1.9/share/julia/stdlib/v1.9/LinearAlgebra/src/generic.jl:453
    [11] normInf(x::NoTangent)
       @ LinearAlgebra ~/packages/julias/julia-1.9/share/julia/stdlib/v1.9/LinearAlgebra/src/generic.jl:527
    [12] generic_norm2(x::NoTangent)
       @ LinearAlgebra ~/packages/julias/julia-1.9/share/julia/stdlib/v1.9/LinearAlgebra/src/generic.jl:463
    [13] norm2
       @ ~/packages/julias/julia-1.9/share/julia/stdlib/v1.9/LinearAlgebra/src/generic.jl:529 [inlined]
    [14] norm(itr::NoTangent, p::Int64)
       @ LinearAlgebra ~/packages/julias/julia-1.9/share/julia/stdlib/v1.9/LinearAlgebra/src/generic.jl:598
--- the last 14 lines are repeated 5712 more times ---

I encountered this while implementing a custom rrule for an integration routine. Integrating the tangent checks its norm for convergence, which is where the overflow surfaced. Maybe this should error or return NaN instead.

@mcabbott added the bug label Dec 4, 2023
@oxinabox
Member

oxinabox commented Dec 7, 2023

I believe this should just return NoTangent() because norm is a linear operator.

@oxinabox added the good first issue label Dec 7, 2023
@sethaxen
Member

sethaxen commented Dec 7, 2023

norm is not a linear map. e.g. iszero(norm(x) + norm(-x)) only when iszero(x)
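
To make the counterexample concrete, here is a minimal sketch using only the LinearAlgebra standard library. The vector x is an arbitrary illustration, not something from the thread. A linear map f would satisfy f(x) + f(-x) == f(x + (-x)), and norm does not:

```julia
using LinearAlgebra

x = [3.0, 4.0]            # arbitrary nonzero vector, chosen for illustration

# For a linear map f: f(x) + f(-x) == f(x + (-x)) == f(0) == 0.
lhs = norm(x) + norm(-x)  # 5.0 + 5.0
rhs = norm(x + (-x))      # norm of the zero vector

@show lhs rhs
```

The two sides only agree when x itself is zero, which is the point made above.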

@oxinabox
Member

oxinabox commented Dec 8, 2023

Fair point.
On that basis, why are you calling norm on a tangent?
Because it is not a linear map, that suggests your pullback (or pushforward) is incorrect:
pullbacks (and pushforwards) are always linear maps on their inputs.

@lukas-weber
Contributor Author

Hmm, it is a pullback for an integrator, which I implement by calling the pullback of the integrand and integrating the resulting tangent using the integrator. The integrator checks convergence by calling norm on the tangent. So apart from this convergence check, the pullback is linear in its inputs.

In other words, when summing an infinite series of tangents, it’s good to have a way to tell when the series has converged, and I guess for that you need a notion of distance. For tangents without any NoTangent entries hidden inside, this distance is induced by the current implementation of norm. It might make sense to define that distance also between tangents that have a NoTangent entry somewhere in them.

Leaving everything else the same, defining norm(::NoTangent) = 0 would give the correct distance for the application above, but I’m not sure if it’s a good idea in general.
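
As a rough sketch of that idea, the snippet below uses a hypothetical stand-in type, StandInNoTangent, instead of the real ChainRulesCore.NoTangent, so that it runs with the standard library alone and does not add methods to a type it does not own. The converged helper and its rtol keyword are also assumptions made for illustration; they are not how QuadGK or any other integrator is actually written.

```julia
using LinearAlgebra

# Hypothetical stand-in for ChainRulesCore.NoTangent (illustration only).
struct StandInNoTangent end

# Adding two "no tangent" markers yields another one.
Base.:+(::StandInNoTangent, ::StandInNoTangent) = StandInNoTangent()

# The definition floated above: a missing tangent has zero length.
LinearAlgebra.norm(::StandInNoTangent) = 0

# Hypothetical convergence check of the kind an integrator might perform.
converged(err, total; rtol = 1e-8) = norm(err) <= rtol * norm(total)

@show norm(StandInNoTangent())
@show converged(StandInNoTangent(), StandInNoTangent())
```

With a zero norm, the check passes immediately instead of recursing, which matches the "correct distance for the application above" but says nothing about whether it is the right choice in general.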

I think norm(::NoTangent) = NoTangent() as you proposed is also sensible as it would satisfy norm(NoTangent())^2 === NoTangent()*NoTangent() === dot(NoTangent(), NoTangent()). Although I wonder if dot(NoTangent(), NoTangent()) returning NoTangent() as it does currently is the best definition, because the result of the scalar product should not be a member of the tangent space.

@sethaxen
Member

sethaxen commented Dec 9, 2023

@lukas-weber are you able to share your pullback code with us?

@lukas-weber
Contributor Author

lukas-weber commented Dec 9, 2023

Sure, here is a minimal example using QuadGK as the integrator.

using ChainRulesCore
using Zygote
using QuadGK
using LinearAlgebra

struct TangentWrapper{T}
    tangent::T
end

Base.:(+)(a::TangentWrapper, b::TangentWrapper) = TangentWrapper(a.tangent + b.tangent)
Base.:(-)(a::TangentWrapper, b::TangentWrapper) = TangentWrapper(a.tangent + -b.tangent)
Base.:(*)(t::TangentWrapper, f::Number) = TangentWrapper(f * t.tangent)
LinearAlgebra.norm(t::TangentWrapper) = norm(t.tangent)

function integrate(func)
    return quadgk(x->exp(-x^2) * func(x), -Inf, Inf)[1]
end

function ChainRulesCore.rrule(config::RuleConfig, ::typeof(integrate), func)
    y = integrate(func)
    project = ProjectTo(func)

    function integrate_pullback(Δy)
        function dfunc_integrand(x)
            _, inner_rrule = ChainRulesCore.rrule_via_ad(config, func, x)
            return TangentWrapper(inner_rrule(Δy)[1])
        end

        return NoTangent(), @thunk(project(integrate(dfunc_integrand).tangent))
    end

    return y, integrate_pullback
end


function test()
    a = 10
    b = 1.0
    has_no_tangent(x) = sum(fill(x, a))

    example_func(x) = cos(b * x) + has_no_tangent(x)

    @show integrate(example_func)
    @show gradient(f->integrate(f), example_func)
end

The TangentWrapper type was introduced to provide the binary - method which was also missing.

As such, the code works as long as the AD-provided tangent of example_func does not contain any NoTangents down its tree. The sum(fill(x,a)) closure is a contrived construction designed to produce one. In practice it happened to me with a more complicated integrand operating on structured types.

(This was inspired by a similar rrule that was already implemented in Integrals.jl, but not fit for my use case)

@oxinabox
Member

I think this makes sense, norm is often used in exactly this way.
I think we can add this overload.
At least for AbstractZero subtypes and NotImplemented
Feel encouraged to make a PR.
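
For readers unfamiliar with how a single overload can cover a family of types, here is a minimal sketch of dispatching on an abstract supertype. The Sketch* types are hypothetical stand-ins for ChainRulesCore's zero-like tangents, and returning 0 is an assumption carried over from the discussion above. This is not the code that was eventually merged.

```julia
using LinearAlgebra

# Hypothetical stand-ins for ChainRulesCore's zero-like tangent types.
abstract type SketchAbstractZero end
struct SketchNoTangent <: SketchAbstractZero end
struct SketchZeroTangent <: SketchAbstractZero end

# One pair of methods on the abstract supertype covers every subtype.
# The second method handles the explicit-p form, e.g. norm(x, 2) or norm(x, Inf).
LinearAlgebra.norm(::SketchAbstractZero) = 0
LinearAlgebra.norm(::SketchAbstractZero, ::Real) = 0

results = (
    norm(SketchNoTangent()),
    norm(SketchZeroTangent()),
    norm(SketchNoTangent(), 2),
    norm(SketchZeroTangent(), Inf),
)
@show results
```

Covering the two-argument form matters here because the stack trace in the report passes through norm(itr, p), so a one-argument method alone would leave norm(x, p) on the generic path.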

@lukas-weber
Contributor Author

Okay, I’ll take a stab at it.
