Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tangent_arithmetic: add norm for NoTangent, ZeroTangent and NotImplemented #642

Merged
merged 1 commit into from
Dec 18, 2023
Merged

tangent_arithmetic: add norm for NoTangent, ZeroTangent and NotImplemented #642

merged 1 commit into from
Dec 18, 2023

Conversation

lukas-weber
Copy link
Contributor

Fixes #639.

norm now returns 0 for the named tangents.

I noticed, however, that the similar implementations of dot(tangent, tangent) or tangent*tangent return ZeroTangent() instead. To my understanding, returning a tangent in those cases is never correct (probably tangent*tangent should not be defined at all). Probably that’s a different issue, but here it is the rationale for sticking with 0 for norm.

Please let me know if there is anything I can do to improve this!

Copy link

codecov bot commented Dec 14, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (cb1aa6b) 93.90% compared to head (e63f62b) 93.79%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #642      +/-   ##
==========================================
- Coverage   93.90%   93.79%   -0.12%     
==========================================
  Files          15       14       -1     
  Lines         903      886      -17     
==========================================
- Hits          848      831      -17     
  Misses         55       55              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@@ -19,6 +19,7 @@ Notice:
Base.:+(x::NotImplemented, ::NotImplemented) = x
Base.:*(x::NotImplemented, ::NotImplemented) = x
LinearAlgebra.dot(x::NotImplemented, ::NotImplemented) = x
LinearAlgebra.norm(x::NotImplemented, p::Real=2) = 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be:

Suggested change
LinearAlgebra.norm(x::NotImplemented, p::Real=2) = 0
LinearAlgebra.norm(x::NotImplemented, p::Real=2) = throw(NotImplementedException(x))

Or

Suggested change
LinearAlgebra.norm(x::NotImplemented, p::Real=2) = 0
LinearAlgebra.norm(x::NotImplemented, p::Real=2) = x

if we want to defer it to later (when presumably convert or something will get called on it?)

And it should probably be in the src/tangent_types/not_implemented.jl

Since the point of calling norm on tangents is to assess their magnitude for convergence checks.
and you can't do a convergence check if one of those tangents was not implemented so has an unknown value

Copy link
Contributor Author

@lukas-weber lukas-weber Dec 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, suggestion 1 is already the current behavior on main.

In the gradient descent use case, I would agree, but in the integrating tangents case, where you are interested in the norm of the difference of two tangents, that would break.

Assume inside the integrand pullback you have a tangent with one of its members being NotImplemented. Then, the expected result would be the integrated tangent with that same member NotImplemented.

But if you check the norm of the difference of two such tangents to see if you have converged to that expected result, suggestion 1 will throw and suggestion 2 will return a NotImplemented that will win against all other components of the tangent and return as the compound norm, making the convergence check impossible.

So zero is still the definition to make it work as expected in that case. In the gradient descent use case, you would throw anyway as soon as you add a partially NotImplemented tangent to the minimizer.

I guess the mathematically clean way is to implement some distance(tangent1, tangent2) that is norm(tangent1 - tangent2) if they are fully implemented and zero if they are equal up to NotImplemented. But maybe just returning 0 is okay.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

problem is we can't be sure two NotImplemented are equal either, in general.
I think in that case needs the user to add in that extra info by preprocessing before calling norm

Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.
I agree with your logic. tangent*tangent shouldn't really be implemented, and dot should return a scalar. it is left over from a much older version. Feel welcome to open an issue about this.

Before we can merge a few things:

  • fix the formatting problems identify by ReviewDog
  • fix the issues i highlighed with NotImplemented

@lukas-weber
Copy link
Contributor Author

I removed the implementation for NotImplemented for now, so that the behavior is like in suggestion 1. I also fixed the formatting and moved the definition for norm(::AbstractZero) to abstract_zero.jl.

@oxinabox oxinabox merged commit 86a3256 into JuliaDiff:main Dec 18, 2023
18 of 27 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

norm(NoTangent()) causes StackOverflow
2 participants