tangent_arithmetic: add norm for NoTangent, ZeroTangent and NotImplemented #642
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #642      +/-   ##
==========================================
- Coverage   93.90%   93.79%   -0.12%
==========================================
  Files          15       14       -1
  Lines         903      886      -17
==========================================
- Hits          848      831      -17
  Misses         55       55
src/tangent_arithmetic.jl
Outdated
@@ -19,6 +19,7 @@ Notice:
 Base.:+(x::NotImplemented, ::NotImplemented) = x
 Base.:*(x::NotImplemented, ::NotImplemented) = x
 LinearAlgebra.dot(x::NotImplemented, ::NotImplemented) = x
+LinearAlgebra.norm(x::NotImplemented, p::Real=2) = 0
This should be:
- LinearAlgebra.norm(x::NotImplemented, p::Real=2) = 0
+ LinearAlgebra.norm(x::NotImplemented, p::Real=2) = throw(NotImplementedException(x))
Or
- LinearAlgebra.norm(x::NotImplemented, p::Real=2) = 0
+ LinearAlgebra.norm(x::NotImplemented, p::Real=2) = x
if we want to defer it until later (when presumably convert or something will get called on it?).
It should also probably live in src/tangent_types/not_implemented.jl.
The point of calling norm on tangents is to assess their magnitude for convergence checks, and you can't do a convergence check if one of those tangents was not implemented and so has an unknown value.
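To make the three candidate behaviors concrete, here is a minimal sketch. It deliberately does not use ChainRulesCore: UnknownTangent and UnknownTangentError are hypothetical stand-ins for NotImplemented and NotImplementedException, and the three norm_* functions are illustrative names only, so treat it as a picture of the trade-off rather than the package's code.

```julia
# Hypothetical stand-ins, NOT the ChainRulesCore types.
struct UnknownTangent
    info::String
end

struct UnknownTangentError <: Exception
    info::String
end

# Option A (the line under review): pretend the unknown tangent has zero magnitude.
norm_zero(x::UnknownTangent, p::Real=2) = 0

# Option B (first suggestion): fail loudly, since the magnitude is unknown.
norm_throw(x::UnknownTangent, p::Real=2) = throw(UnknownTangentError(x.info))

# Option C (second suggestion): propagate the marker and let a later step decide.
norm_defer(x::UnknownTangent, p::Real=2) = x

x = UnknownTangent("rule not written yet")

# A convergence check silently passes under option A ...
converged = norm_zero(x) < 1e-8

# ... and cannot be evaluated at all under option C, because the result is not a number.
deferred = norm_defer(x)
```

Under option A the check reports convergence even though nothing is known about the value, which is the concern raised above.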
Actually, suggestion 1 is already the current behavior on main.
In the gradient descent use case I would agree, but in the integrating-tangents case, where you are interested in the norm of the difference of two tangents, that would break.
Assume that inside the integrand pullback you have a tangent with one of its members being NotImplemented. Then the expected result would be the integrated tangent with that same member NotImplemented.
But if you check the norm of the difference of two such tangents to see whether you have converged to that expected result, suggestion 1 will throw, and suggestion 2 will return a NotImplemented that wins against all other components of the tangent and is returned as the compound norm, making the convergence check impossible.
So zero is still the definition that makes it work as expected in that case. In the gradient descent use case, you would throw anyway as soon as you add a partially NotImplemented tangent to the minimizer.
I guess the mathematically clean way is to implement some distance(tangent1, tangent2) that is norm(tangent1 - tangent2) if they are fully implemented and zero if they are equal up to NotImplemented. But maybe just returning 0 is okay.
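A rough sketch of what such a distance could look like, assuming tangents are modeled as plain NamedTuples and using a hypothetical UnknownTangent marker in place of NotImplemented. The names distance and sqdist are illustrative and not part of ChainRulesCore. Note that the sketch assumes two unknown fields count as equal, which is exactly the assumption that would need justifying.

```julia
using LinearAlgebra

# Hypothetical stand-in for NotImplemented.
struct UnknownTangent end

# Squared distance contributed by a single field.
# Assumption: two unknown fields are treated as "equal up to NotImplemented".
sqdist(::UnknownTangent, ::UnknownTangent) = 0.0
# Comparing an unknown field against a known one has no meaningful answer.
sqdist(::UnknownTangent, b) = throw(ArgumentError("cannot compare unknown with known field"))
sqdist(a, ::UnknownTangent) = throw(ArgumentError("cannot compare known with unknown field"))
# Fully implemented fields: the ordinary squared 2-norm of the difference.
sqdist(a, b) = norm(a - b)^2

# Distance between two tangents that share the same field names K.
function distance(t1::NamedTuple{K}, t2::NamedTuple{K}) where {K}
    return sqrt(sum(sqdist(getfield(t1, k), getfield(t2, k)) for k in K; init=0.0))
end

t1 = (a=[3.0, 0.0], b=UnknownTangent())
t2 = (a=[0.0, 4.0], b=UnknownTangent())
d = distance(t1, t2)  # only field `a` contributes
```

With this shape the convergence check compares only the implemented members, and a mismatch between a known and an unknown member still raises an error instead of being hidden.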
The problem is that we can't be sure two NotImplemented are equal either, in general.
I think that case needs the user to add that extra information by preprocessing before calling norm.
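One way such preprocessing might look, again as a sketch under assumptions: tangents are modeled as NamedTuples, UnknownTangent is a hypothetical stand-in for NotImplemented, and known_part and tangent_norm are made-up helper names. The user explicitly drops the members they have decided to ignore, and only then takes a norm.

```julia
using LinearAlgebra

# Hypothetical stand-in for NotImplemented.
struct UnknownTangent end

# Keep only the fields whose value is actually known.
function known_part(t::NamedTuple)
    ks = Tuple(k for k in keys(t) if !(t[k] isa UnknownTangent))
    return NamedTuple{ks}(t)
end

# 2-norm over the remaining fields, combining the per-field norms.
tangent_norm(t::NamedTuple) = sqrt(sum(abs2 ∘ norm, values(known_part(t)); init=0.0))

t = (a=[3.0, 4.0], b=UnknownTangent())
n = tangent_norm(t)  # field `b` is dropped before the norm is taken
```

The decision to ignore the unknown member is then made visibly in user code, not implied by a package-level definition of norm.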
Thanks.
I agree with your logic. tangent*tangent shouldn't really be implemented, and dot should return a scalar. It is left over from a much older version. Feel welcome to open an issue about this.
Before we can merge, a few things:
- fix the formatting problems identified by ReviewDog
- fix the issues I highlighted with NotImplemented
I removed the implementation for |
Fixes #639.
norm now returns 0 for the named tangents.
I noticed, however, that the similar implementations of dot(tangent, tangent) or tangent*tangent return ZeroTangent() instead. To my understanding, returning a tangent in those cases is never correct (probably tangent*tangent should not be defined at all). That is probably a different issue, but it is the rationale for sticking with 0 for norm.
Please let me know if there is anything I can do to improve this!