Hi! First, I want to say that I've greatly benefited from being able to use JAX. I have a question about intended behavior in certain corner cases: when the subdifferential of a function contains more than one "right answer," what is permissible behavior for JAX? I've noticed a small inconsistency in JAX's behavior for cases like this:

```python
from jax import numpy as np
from jax import grad

x, y = 0.0, 0.0

# jax breaks "ties" for grad(abs)(0.0)
def f(x, y):
    return np.abs(x - y)

# but not for grad(norm)(something_with_zero_norm)
def g(x, y):
    return np.linalg.norm(x - y)

# (+1, -1), (nan, nan)
grad(f, argnums=(0, 1))(x, y), grad(g, argnums=(0, 1))(x, y)
```

In the first case, it appears permissible to return an arbitrary point in the subdifferential, and in the second case it appears impermissible. I imagine the intended behavior might be to return some fixed element of the subdifferential in both cases. Alternately, it might be intended for JAX to return an arbitrary value in the interval [-1, +1] for the first case, and an arbitrary value in the unit ball for the second case, but I don't know how safe that is.
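To make the inconsistency concrete: for scalar arguments, `f` and `g` compute the same value, since the norm of a scalar is just its absolute value, so only the gradient rules differ. A minimal check (assuming a recent JAX):

```python
from jax import numpy as np
from jax import grad

def f(x, y):
    return np.abs(x - y)

def g(x, y):
    return np.linalg.norm(x - y)

# Identical primal values at the tie point...
print(f(0.0, 0.0), g(0.0, 0.0))  # both 0.0
# ...but different gradients with respect to x.
print(grad(f)(0.0, 0.0))  # 1.0
print(grad(g)(0.0, 0.0))  # nan
```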
I believe this comes from the fact that `norm(x)` involves a square root, which has an infinite gradient at `x=0`:

```python
grad(np.sqrt)(0.0)
# DeviceArray(inf, dtype=float32)
```

I believe this is the only sensible derivative of `sqrt` for this value, so `NaN` is the only sensible derivative of `norm` at zero as implemented.

On the other hand, the gradient of `abs` is currently defined as `1` at `x=0`, which is sensible in the sense that it is the correct right-limit, although the left limit is `-1`. It may also be sensible to return `NaN` for the gradient of `abs` at zero, but I'm not certain what implications that might have in other places where `abs` is used.

In any case, the "correct" answer here is ambiguous, and if your code depends on a particular convention, the best option would be to define a custom gradient for the function in question; see https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html.
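For example, here's one way that could look, using `jax.custom_jvp`; the name `safe_norm` and the choice of the zero element of the subdifferential at the origin are my own conventions for illustration, not anything JAX prescribes:

```python
import jax
from jax import numpy as np

@jax.custom_jvp
def safe_norm(x):
    return np.linalg.norm(x)

@safe_norm.defjvp
def safe_norm_jvp(primals, tangents):
    x, = primals
    t, = tangents
    y = np.linalg.norm(x)
    # At the origin, choose the zero element of the subdifferential
    # instead of letting the sqrt produce inf * 0 = nan.
    denom = np.where(y == 0, 1.0, y)  # avoid 0/0 in the untaken branch
    grad_x = np.where(y == 0, np.zeros_like(x), x / denom)
    return y, np.sum(grad_x * t)
```

With this, `grad(safe_norm)` at the origin returns zeros rather than NaNs, while gradients away from the origin are unchanged.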