Dtype mismatch in linear layer #678
Comments
we could likely print
I am seeing a different error related to advanced indexing:

[rank0]: File "thunder/core/proxies.py", line 1333, in __getitem__
[rank0]: return method(self, key)
[rank0]: File "thunder/core/symbol.py", line 268, in __call__
[rank0]: result = self.meta(*args, **kwargs)
[rank0]: File "thunder/core/langctxs.py", line 132, in _fn
[rank0]: result = fn(*args, **kwargs)
[rank0]: File "thunder/torch/__init__.py", line 890, in getitem
[rank0]: return clang.getitem(a, key)
[rank0]: File "thunder/core/langctxs.py", line 132, in _fn
[rank0]: result = fn(*args, **kwargs)
[rank0]: File "thunder/clang/__init__.py", line 868, in getitem
[rank0]: return _advanced_indexing(a, key)
[rank0]: File "thunder/core/langctxs.py", line 132, in _fn
[rank0]: result = fn(*args, **kwargs)
[rank0]: File "thunder/clang/__init__.py", line 729, in _advanced_indexing
[rank0]: utils.check(
[rank0]: File "thunder/core/baseutils.py", line 103, in check
[rank0]: raise exception_type(s())
[rank0]: RuntimeError: Advanced indexing currently only supports zero or one-dimensional integer tensors, but found a tensor with dtype int64 and 2 dimensions

thunder commit used: 72e033a
Full log: neva.log
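For context, here is a minimal sketch (my own illustration, not taken from the log above) of the indexing pattern that trips this check. Eager PyTorch accepts a 2-D integer index tensor, while thunder's _advanced_indexing currently rejects it:

import torch

x = torch.randn(8, 8)
idx = torch.tensor([[0, 1], [2, 3]])  # 2-D int64 index tensor

# Eager PyTorch supports this; the result has shape (2, 2, 8).
print(x[idx].shape)

# Tracing the same indexing through thunder.jit hits the check in
# thunder/clang/__init__.py (_advanced_indexing) and raises the
# RuntimeError quoted above, since only 0- or 1-D integer index
# tensors are currently supported.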
triage review
I have been able to repro the failure with an independent script. The failure happens due to the interaction of autocast and mixed input dtypes:

import thunder
import torch

def foo(x, w):
    return torch.nn.functional.linear(x, w)

device = torch.device("cuda")
with device:
    # Mixed input types.
    x, w = torch.randn(16, 16, dtype=torch.bfloat16), torch.randn(16, 16)
    # Same input types (works with thunder).
    # x, w = torch.randn(16, 16), torch.randn(16, 16)

print(x.dtype, w.dtype)

with torch.autocast("cuda", torch.bfloat16):
    # Eager autocast handles mixed input types.
    eager_out = foo(x, w)
    # `thunder.jit` doesn't handle mixed inputs.
    jfoo = thunder.jit(foo)
    jit_out = jfoo(x, w)

print(thunder.last_traces(jfoo)[-1])
torch.testing.assert_close(eager_out, jit_out)
Great! Thank you, excellent work :-)
The reason it currently fails while tracing with thunder.jit: with mixed input dtypes, we fail at step 1, as these operators don't allow mixed inputs. (In eager mode, with the context manager active, the dispatcher first applies the conversion before passing the converted inputs to the operators; a rough sketch of this is below.)
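For illustration, here is a rough sketch of what the eager dispatcher effectively does for an autocast-eligible op like linear. This is my own simplification, assuming bfloat16 as the autocast dtype; it is not PyTorch or thunder source:

import torch

def autocast_linear(x, w, autocast_dtype=torch.bfloat16):
    # Under autocast, floating-point inputs are converted to the autocast
    # dtype before dispatch, so the kernel only ever sees uniform dtypes.
    x = x.to(autocast_dtype) if x.is_floating_point() else x
    w = w.to(autocast_dtype) if w.is_floating_point() else w
    return torch.nn.functional.linear(x, w)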
Potential fix: @t-vi, I would like your opinion on this and some pointers. Thank you!
Great analysis @kshitij12345! For 1: we do have autocast handling in thunder.jit and cache_info; see lightning-thunder/thunder/__init__.py, line 395 at da23a0b.

For 2: to my mind, this is more of a thunder.torch thing; see lightning-thunder/thunder/torch/__init__.py, line 108 at da23a0b.
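To make the thunder.torch direction concrete, here is a hypothetical sketch of converting mixed floating-point inputs to the autocast dtype before the underlying op runs. The names (maybe_convert, with_autocast_inputs) are illustrative, not thunder's actual API:

import torch

def maybe_convert(t, autocast_dtype):
    # Convert floating-point tensors to the autocast dtype; pass
    # everything else through unchanged.
    if isinstance(t, torch.Tensor) and t.is_floating_point():
        return t.to(autocast_dtype)
    return t

def with_autocast_inputs(fn, autocast_dtype=torch.bfloat16):
    # Wrap a torch-level op so its tensor inputs are made dtype-uniform
    # first, mirroring what the eager dispatcher does under autocast.
    def wrapped(*args, **kwargs):
        args = tuple(maybe_convert(a, autocast_dtype) for a in args)
        kwargs = {k: maybe_convert(v, autocast_dtype) for k, v in kwargs.items()}
        return fn(*args, **kwargs)
    return wrapped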
WDYT?
🐛 Bug
Full log of the run that includes the unabbreviated traceback.
To Reproduce
Note you'll need the referenced ./data directory.
Expected behavior
Environment
cc @crcrpar @tfogal