
cat: support inputs with mixed dtypes #819

Merged: 4 commits into Lightning-AI:main from cat-upcast, Jul 23, 2024

Conversation

kshitij12345 (Collaborator)

Fixes #812

TODO - Look into test_vjp_correctness failure
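
For context (not from this PR's diff): PyTorch eager already type-promotes mixed-dtype inputs to torch.cat, so the fix presumably computes the common dtype and upcasts the inputs before concatenating. A minimal sketch of the promotion for the bfloat16/float16 case from the linked issue:

import torch

a = torch.ones(2, dtype=torch.bfloat16)
b = torch.ones(2, dtype=torch.float16)

# Neither bfloat16 nor float16 can represent the other, so both promote to float32.
print(torch.promote_types(a.dtype, b.dtype))  # torch.float32
print(torch.cat([a, b], dim=0).dtype)         # torch.float32 in eager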

kshitij12345 changed the title from "cat: support inputs with mixed dtypes" to "[WIP] cat: support inputs with mixed dtypes" on Jul 22, 2024
kshitij12345 (Collaborator, Author)

test_vjp_correctness failure: check_vjp verifies the output against a numerically computed Jacobian. However, with mixed-dtype inputs (float and double), the numerical-differentiation output is slightly less accurate for the float part of the inputs, leading to a mismatch. I think this is expected, and we should probably just increase the tolerance for the test (or, if there is a way, increase the tolerance only for this sample; a sketch of a looser comparator follows below).

cc: @IvanYashchuk (original author of test)
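
Not from the PR, just a hedged sketch of what loosening the comparison could look like, assuming it is done at the comparator level (the test suite's actual tolerance mechanism may differ):

from functools import partial
import torch

# float32 finite differences carry error on the order of 1e-4,
# so the default 1e-7 tolerances used for this comparison are too tight.
loose_comp = partial(torch.testing.assert_close, atol=1e-4, rtol=1e-4)

# The mismatch observed in the repro below passes under the looser tolerances.
loose_comp(torch.tensor(1.999905726350911, dtype=torch.float64), torch.tensor(2.0, dtype=torch.float64))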

Smaller Repro

from functools import partial
import torch
import thunder
from thunder.tests.make_tensor import make_tensor, make_tensor_like
from thunder.core.pytree import tree_map, tree_flatten
from thunder.tests.test_grad import numerical_jvp, _dot, vjp as thunder_vjp, Sequence, flatten_func, _make_differentiable_wrapper


# Copied from `test_grad` with minor changes to support torch eager
# and to create `u` and `v` as ones tensors (for easier reasoning).
def check_vjp(f, *primals, comp, eager=False):
    """Check that the vector-Jacobian product of a function is correct.

    Args:
        f (callable): The function to differentiate.
        *primals (torch.Tensor): The input tensors.
        executor (str): The executor to use. Defaults to "torch".
        atol (float): Absolute tolerance. Defaults to None.
        rtol (float): Relative tolerance. Defaults to None.

    Raises:
        AssertionError: If the vector-Jacobian product is not correct.
    """
    # Let f be a function from vectors of size n to vectors of size m.
    # Its Jacobian is a matrix J of size m x n.
    # J^*, the conjugate transpose (adjoint) of J, is a matrix of size n x m.
    # For any vector v of size m, J^* v is a vector of size n.
    # For any vector u of size n, J u is a vector of size m.
    # The adjoint property says the dot product of J u with v equals the dot product of u with J^* v.
    # This function checks that equality:
    # 〈J u, v〉 == 〈u, J^* v〉
    # Since u and v can be arbitrary, the original test takes u = rand_like(primals) and v = rand_like(f(primals));
    # here they are ones tensors for easier reasoning.
    # We compute J u using numerical_jvp, and J* v using Thunder's vjp. That way we check correctness of Thunder's vjp.
    # Using finite differences we can compute J u, but we can't compute J* v, without computing full J, which is expensive.

    make = partial(torch.ones_like)  # original: partial(make_tensor_like, low=0, high=1)

    u = tree_map(make, primals)
    if eager:
        outs_p, J_u = numerical_jvp(f)(primals, u)
    else:
        outs_p, J_u = numerical_jvp(thunder.compile(f, disable_torch_autograd_support=True, disable_preprocessing=True))(primals, u)

    multiple_results = isinstance(outs_p, Sequence)

    v = tree_map(make, outs_p)
    if eager:
        _, J_star_v = torch.autograd.functional.vjp(f, primals, v)
    else:
        _, J_star_v = thunder.compile(thunder_vjp(f), disable_torch_autograd_support=True, disable_preprocessing=True)(primals, v)

    if not multiple_results:
        v = (v,)
        J_u = (J_u,)

    J_u_v = _dot(J_u, v)
    u_J_star_v = _dot(u, J_star_v)
    if J_u_v.isnan().any():
        # TODO: find a better way to handle NaNs in finite differences
        return  # skip this sample
    comp(J_u_v, u_J_star_v)


# Check thunder - torch.cat
f = thunder.torch.cat
primals = ((torch.ones(1, requires_grad=True), torch.ones(1, dtype=torch.double, requires_grad=True)),)
kwargs = {"dim": 0}
flat_op, flat_args, spec = flatten_func(f, primals, kwargs)

filtered_op, filtered_args = _make_differentiable_wrapper(flat_op, flat_args)
# AssertionError: Scalars are not close!

# Expected 2.0 but got 1.999905726350911.
# Absolute difference: 9.427364908898284e-05 (up to 1e-07 allowed)
# Relative difference: 4.713682454449142e-05 (up to 1e-07 allowed)
check_vjp(filtered_op, *filtered_args, comp=torch.testing.assert_close)

# Check eager - torch.cat
f = torch.cat
flat_op, flat_args, spec = flatten_func(f, primals, kwargs)
filtered_op, filtered_args = _make_differentiable_wrapper(flat_op, flat_args)

# AssertionError: Scalars are not close!

# Expected 2.0 but got 1.999905726350911.
# Absolute difference: 9.427364908898284e-05 (up to 1e-07 allowed)
# Relative difference: 4.713682454449142e-05 (up to 1e-07 allowed)
check_vjp(filtered_op, *filtered_args, comp=torch.testing.assert_close, eager=True)
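
For reference, a minimal sketch of where the ~1e-4 discrepancy comes from: rounding error of about eps_float32 (~1.2e-7) in the function values is amplified by 1 / (2 * step) in the central-difference quotient, landing around 1e-4 for typical step sizes, while float64 stays far below the 1e-7 tolerance. The identity function is a fair proxy because cat's Jacobian acts as an identity map on each input slice; the step size below is illustrative and may not match what numerical_jvp actually uses.

import torch

def central_diff(f, x, step):
    # Central finite-difference approximation of df/dx at x.
    return (f(x + step) - f(x - step)) / (2 * step)

identity = lambda x: x  # true derivative is exactly 1.0
step = 1e-4             # illustrative step size

d32 = central_diff(identity, torch.tensor(1.0, dtype=torch.float32), step)
d64 = central_diff(identity, torch.tensor(1.0, dtype=torch.float64), step)

print(abs(d32.item() - 1.0))  # on the order of 1e-4: float32 rounding dominates
print(abs(d64.item() - 1.0))  # roughly 1e-12: well within the default tolerance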

kshitij12345 changed the title from "[WIP] cat: support inputs with mixed dtypes" to "cat: support inputs with mixed dtypes" on Jul 23, 2024
kshitij12345 marked this pull request as ready for review on July 23, 2024, 12:03
t-vi (Collaborator) left a comment


Looks great.

t-vi (Collaborator) commented Jul 23, 2024

Thank you @kshitij12345

t-vi merged commit fda3fce into Lightning-AI:main on Jul 23, 2024
39 checks passed
github-actions bot deleted the cat-upcast branch on October 23, 2024, 00:46

Successfully merging this pull request may close these issues: Dtype mismatch in cat: bfloat16 and float16