`norm(NoTangent())` causes StackOverflow #639
I believe this should just return […]
Fair point.
Hmm, it is a pullback for an integrator, which I implement by calling the pullback of the integrand and integrating the resulting tangent with the integrator. The integrator checks convergence by calling […] In other words, when summing an infinite series of tangents, it's good to have a way to tell when the series has converged, and for that I guess you need a notion of distance. For tangents without any […] Leaving everything else the same, defining […] I think […]
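To illustrate why a convergence check needs a notion of distance, here is a minimal sketch of summing a series of tangents until the norm of the latest term drops below a tolerance. It is not QuadGK's actual adaptive algorithm, and `sum_series` with its `atol`/`maxiter` keywords is a hypothetical helper; the only assumption is that the term type supports `+` and `LinearAlgebra.norm`.

```julia
using LinearAlgebra

# Hypothetical helper (not part of QuadGK or ChainRulesCore): add up
# term(1), term(2), ... and stop once the norm of the most recent term
# falls below `atol`, or after `maxiter` terms.
# Works for any term type that supports `+` and `LinearAlgebra.norm`.
function sum_series(term; atol=1e-12, maxiter=10_000)
    total = term(1)
    for n in 2:maxiter
        t = term(n)
        total += t
        norm(t) < atol && return total  # converged: latest term is negligible
    end
    return total
end

# Geometric series of vector-valued terms: sum over n >= 1 of 0.5^n .* [1, 2]
s = sum_series(n -> 0.5^n .* [1.0, 2.0])
```

Feeding this loop `NoTangent()` terms would reach `norm(NoTangent())` in the convergence check, which is exactly the call reported in this issue.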
@lukas-weber are you able to share your pullback code with us?
Sure, here is a minimal example using QuadGK as the integrator.

```julia
using ChainRulesCore
using Zygote
using QuadGK
using LinearAlgebra

# Wrap a tangent so that QuadGK can add, subtract, scale, and measure it.
struct TangentWrapper{T}
    tangent::T
end
Base.:(+)(a::TangentWrapper, b::TangentWrapper) = TangentWrapper(a.tangent + b.tangent)
Base.:(-)(a::TangentWrapper, b::TangentWrapper) = TangentWrapper(a.tangent + -b.tangent)
# Scaling is needed from both sides: `exp(-x^2) * func(x)` puts the number first.
Base.:(*)(t::TangentWrapper, f::Number) = TangentWrapper(f * t.tangent)
Base.:(*)(f::Number, t::TangentWrapper) = t * f
# QuadGK's convergence check calls `norm` on the integrand values.
LinearAlgebra.norm(t::TangentWrapper) = norm(t.tangent)

# Integrate `func` against a Gaussian weight over the whole real line.
function integrate(func)
    return quadgk(x -> exp(-x^2) * func(x), -Inf, Inf)[1]
end

function ChainRulesCore.rrule(config::RuleConfig, ::typeof(integrate), func)
    y = integrate(func)
    project = ProjectTo(func)
    function integrate_pullback(Δy)
        # Pointwise tangent with respect to `func` itself (index 1), obtained by calling back into AD.
        function dfunc_integrand(x)
            _, inner_rrule = ChainRulesCore.rrule_via_ad(config, func, x)
            return TangentWrapper(inner_rrule(Δy)[1])
        end
        # Integrating the pointwise tangents against the same weight gives the tangent of `func`.
        return NoTangent(), @thunk(project(integrate(dfunc_integrand).tangent))
    end
    return y, integrate_pullback
end

function test()
    a = 10
    b = 1.0
    has_no_tangent(x) = sum(fill(x, a))
    example_func(x) = cos(b * x) + has_no_tangent(x)
    @show integrate(example_func)
    @show gradient(f -> integrate(f), example_func)
end

test()
```

The […] As such, the code works as long as the AD-provided tangent of […] (This was inspired by a similar rrule already implemented in Integrals.jl, which did not fit my use case.)
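Until `norm` handles zero tangents, one possible local workaround (my assumption, not something proposed in this thread) is to special-case the wrapper when it holds a structural zero, so the convergence check sees a norm of zero and never calls `norm(NoTangent())` directly. The sketch relies on `ChainRulesCore.AbstractZero`, the common supertype of `NoTangent` and `ZeroTangent`; treating such a tangent as having norm `0.0` is an assumption about what the integrator should see.

```julia
using ChainRulesCore
using LinearAlgebra

struct TangentWrapper{T}
    tangent::T
end

# Generic case: defer to the wrapped tangent's own norm.
LinearAlgebra.norm(t::TangentWrapper) = norm(t.tangent)
# Assumption: a structural zero (NoTangent, ZeroTangent) has norm zero.
# This more specific method wins dispatch, so `norm` is never called on
# the AbstractZero itself.
LinearAlgebra.norm(::TangentWrapper{<:ChainRulesCore.AbstractZero}) = 0.0

n_zero = norm(TangentWrapper(NoTangent()))
n_vec = norm(TangentWrapper([3.0, 4.0]))
```

Because the method is defined on the user's own `TangentWrapper` type, this avoids type piracy on `NoTangent`, at the cost of fixing the behaviour only inside this wrapper.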
I think this makes sense; `norm` is often used in exactly this way.
Okay, I’ll take a stab at it.
Calling `norm(NoTangent())` leads to a StackOverflow.
I encountered this while implementing a custom rrule for an integration routine: integrating the tangent makes the integrator check its norm for convergence, which triggers the overflow. Maybe `norm(NoTangent())` should error or return `NaN` instead.