cat: support inputs with mixed dtypes #819
Conversation
`test_vjp_correctness` failure: cc @IvanYashchuk (original author of the test).

Smaller repro:

```python
from functools import partial

import torch

import thunder
from thunder.tests.make_tensor import make_tensor, make_tensor_like
from thunder.core.pytree import tree_map, tree_flatten
from thunder.tests.test_grad import (
    numerical_jvp,
    _dot,
    vjp as thunder_vjp,
    Sequence,
    flatten_func,
    _make_differentiable_wrapper,
)
# Copied from `test_grad` with minor changes to support torch eager
# and to create `u` and `v` as ones tensors (for easier reasoning).
def check_vjp(f, *primals, comp, eager=False):
"""Check that the vector-Jacobian product of a function is correct.
Args:
f (callable): The function to differentiate.
*primals (torch.Tensor): The input tensors.
executor (str): The executor to use. Defaults to "torch".
atol (float): Absolute tolerance. Defaults to None.
rtol (float): Relative tolerance. Defaults to None.
Raises:
AssertionError: If the vector-Jacobian product is not correct.
"""
    # Let f be a function from vectors of size n to vectors of size m.
    # Its Jacobian is a matrix J of size m x n.
    # Let J^* denote the conjugate transpose (adjoint) of J; it is a matrix of size n x m.
    # For any vector v of size m, J^* v is a vector of size n.
    # For any vector u of size n, J u is a vector of size m.
    # The adjoint property says that the dot product of J u and v equals the dot product of u and J^* v:
    #     〈J u, v〉 == 〈u, J^* v〉
    # This function checks that identity.
    # Since u and v can be arbitrary, we take u = ones_like(primals) and v = ones_like(f(primals)).
    # We compute J u using numerical_jvp and J^* v using Thunder's vjp; that way we check the correctness of Thunder's vjp.
    # Using finite differences we can compute J u, but we can't compute J^* v without computing the full J, which is expensive.
    # make = partial(make_tensor_like, low=0, high=1)
    make = partial(torch.ones_like)
    u = tree_map(make, primals)
    if eager:
        outs_p, J_u = numerical_jvp(f)(primals, u)
    else:
        outs_p, J_u = numerical_jvp(
            thunder.compile(f, disable_torch_autograd_support=True, disable_preprocessing=True)
        )(primals, u)
    multiple_results = isinstance(outs_p, Sequence)

    v = tree_map(make, outs_p)
    if eager:
        _, J_star_v = torch.autograd.functional.vjp(f, primals, v)
    else:
        _, J_star_v = thunder.compile(
            thunder_vjp(f), disable_torch_autograd_support=True, disable_preprocessing=True
        )(primals, v)

    if not multiple_results:
        v = (v,)
        J_u = (J_u,)

    J_u_v = _dot(J_u, v)
    u_J_star_v = _dot(u, J_star_v)
    if J_u_v.isnan().any():
        # TODO: find a better way to handle NaNs in finite differences
        return  # skip this sample
    comp(J_u_v, u_J_star_v)

# Check thunder - torch.cat
f = thunder.torch.cat
primals = ((torch.ones(1, requires_grad=True), torch.ones(1, dtype=torch.double, requires_grad=True)),)
kwargs = {"dim": 0}
flat_op, flat_args, spec = flatten_func(f, primals, kwargs)
filtered_op, filtered_args = _make_differentiable_wrapper(flat_op, flat_args)
# AssertionError: Scalars are not close!
# Expected 2.0 but got 1.999905726350911.
# Absolute difference: 9.427364908898284e-05 (up to 1e-07 allowed)
# Relative difference: 4.713682454449142e-05 (up to 1e-07 allowed)
check_vjp(filtered_op, *filtered_args, comp=torch.testing.assert_close)
# Check eager - torch.cat
f = torch.cat
flat_op, flat_args, spec = flatten_func(f, primals, kwargs)
filtered_op, filtered_args = _make_differentiable_wrapper(flat_op, flat_args)
# AssertionError: Scalars are not close!
# Expected 2.0 but got 1.999905726350911.
# Absolute difference: 9.427364908898284e-05 (up to 1e-07 allowed)
# Relative difference: 4.713682454449142e-05 (up to 1e-07 allowed)
check_vjp(filtered_op, *filtered_args, comp=torch.testing.assert_close, eager=True)
```
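
Both the Thunder check and the eager check fail with exactly the same numbers, and in both branches the `J u` side comes from `numerical_jvp` (finite differences on a float32 input). That points at the finite-difference side rather than at the VJP rule for `cat`. As a hedged sanity check (not part of the repro; the tensors below are illustrative stand-ins mirroring the repro's inputs), the adjoint identity 〈J u, v〉 == 〈u, J^* v〉 can be verified for mixed-dtype `torch.cat` without finite differences, using the fact that `cat` is linear together with PyTorch's analytic `torch.autograd.functional.vjp`:

```python
import torch


def f(a, b):
    # Same operation as in the repro: concatenation of mixed-dtype inputs.
    return torch.cat((a, b), dim=0)


a = torch.ones(1)                      # float32
b = torch.ones(1, dtype=torch.double)  # float64

u = (torch.ones_like(a), torch.ones_like(b))  # tangent, matches the inputs
v = torch.ones_like(f(a, b))                  # cotangent, matches the promoted float64 output

# cat is linear, so its JVP at any point is just cat applied to the tangent.
J_u = f(*u)
# J^* v via PyTorch's analytic VJP (the same call the eager branch above uses).
_, J_star_v = torch.autograd.functional.vjp(f, (a, b), v)

lhs = (J_u * v).sum()                                      # 〈J u, v〉
rhs = sum((ui * gi).sum() for ui, gi in zip(u, J_star_v))  # 〈u, J^* v〉
print(lhs.item(), rhs.item())                              # both print 2.0 exactly
```

With analytic derivatives the two sides agree exactly, which is consistent with the ~9.4e-05 mismatch coming from finite-difference error on the float32 input rather than from the `cat` gradient itself.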
Looks great.
Thank you @kshitij12345
Fixes: #812

TODO: Look into `test_vjp_correctness` failure.
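
One possible angle for the TODO: the reported absolute difference (~9.4e-05) is the size of error one would expect from a finite-difference derivative taken in float32. This is a minimal sketch assuming `numerical_jvp` uses a central difference with a step on the order of the square root of the input dtype's machine epsilon (the exact step it uses is not shown in this thread):

```python
import torch

# Central difference of the identity map f(x) = x, carried out in float32.
# The exact derivative is 1.0, yet float32 rounding alone already misses it
# by far more than the 1e-07 tolerance that assert_close applied above.
x = torch.ones(1)                          # float32, like the first cat input
h = torch.finfo(torch.float32).eps ** 0.5  # a common finite-difference step choice
slope = ((x + h) - (x - h)) / (2 * h)
print((slope - 1.0).abs().item())          # roughly 1e-05 to 1e-04
```

If that is the dominant error source, the failure is about probe precision and tolerances for mixed-dtype samples rather than about the `cat` VJP rule itself.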