enter/exit_autocast of torch.amp.autocast_mode #824
Assigning to me until I can fill out a better reproducer.
Thunder doesn't support PyTorch's context managers like autocast, no_grad, enable_grad, etc. inside the compiled function. With the recent addition of the autocast-specific dispatch at tracing time (#705, #810), supporting this might not take a lot of work; the challenge is to avoid reordering these enter and exit calls inappropriately.
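To make the ordering concern concrete, here is a minimal sketch (mine, not from the issue) using the same `_enter_autocast`/`_exit_autocast` calls that Dynamo emits in the graph below; it assumes a CUDA device is available:

```python
import torch

# Ops between _enter_autocast and _exit_autocast run under autocast; moving
# either call across the matmul would change its compute dtype, which is why
# a tracer must not reorder them relative to the ops they bracket.
x = torch.randn(4, 4, device="cuda")
ctx = torch.amp.autocast_mode._enter_autocast("cuda", torch.bfloat16, True, None)
y = x @ x  # autocast active: matmul runs in bfloat16
torch.amp.autocast_mode._exit_autocast(ctx)
z = x @ x  # autocast inactive: matmul runs in float32
assert y.dtype == torch.bfloat16 and z.dtype == torch.float32
```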
I think this could be the minimal repro:

```python
import torch
import thunder


class ThunderJitBackend:
    def __init__(self, **compile_options) -> None:
        self.thunder_jit_fns = []
        self.dynamo_graphs = []
        self.cnt = 0
        self.compile_options = compile_options

    def compile(self, gm, sample_args):
        self.dynamo_graphs.append(gm)
        gm.real_recompile()
        thunder_jit_fn = thunder.jit(gm, **self.compile_options)
        self.thunder_jit_fns.append(thunder_jit_fn)
        self.cnt += 1
        return thunder_jit_fn


dev = "cuda"


def foo(x):
    with torch.autocast(dev, torch.bfloat16):
        y = x @ x
    return x + 2, y


with torch.device(dev):
    model = foo
    x = torch.randn(16, 16)
    args = (x,)
    kwargs = {}
    jit_backend = ThunderJitBackend()
    cmodel = torch.compile(model, backend=jit_backend.compile)
    o = cmodel(*args, **kwargs)
    print(f"GRAPHS {jit_backend.cnt}")
    print(jit_backend.dynamo_graphs[0])
    for tfn in jit_backend.thunder_jit_fns:
        print(thunder.last_traces(tfn)[-1])
    torch.testing.assert_close(o, model(*args, **kwargs))
```

Dynamo graph:

```python
def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', torch.bfloat16, True, None)
    y = l_x_ @ l_x_
    _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = None
    add = l_x_ + 2; l_x_ = None
    return (add, y)
```

With `dev="cuda"`, we see the following error:

```
  File "/home/kkalambarkar/git/pytorch/torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
  File "/home/kkalambarkar/git/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/kkalambarkar/git/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/core/module.py", line 61, in forward
    res = self._forward_fn(*args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/__init__.py", line 685, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/__init__.py", line 225, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/__init__.py", line 506, in get_computation_and_inputs
    jit_results: TraceResults = interpreter(
  File "/home/kkalambarkar/lightning-thunder/thunder/__init__.py", line 213, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
  File "/home/kkalambarkar/lightning-thunder/thunder/core/jit_ext.py", line 1768, in thunder_general_jit
    result = jfn(*args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/core/interpreter.py", line 6760, in fn_
    raise InterpreterError(msg) from e
thunder.core.interpreter.InterpreterError: Encountered exception TypeError: unhashable type: 'instancemethod' while tracing GraphModule()
```

With `dev="cpu"`, the program compiles but silently ignores the autocast in the function (it computes in single precision), so it fails at `torch.testing.assert_close`.
Removing my assignment because Kshiteej is a hero w.r.t. finding minimal reproducers 😄. Thank you
Yes, this is a great idea for the interim. I'll see if these show up in actionable parts of the code (NeMo, or maybe Megatron).
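For reference, a hedged sketch of one possible interim workaround (my assumption; the specific idea referenced above isn't quoted here): apply autocast around the call to the jitted function rather than inside it, relying on Thunder's trace-time autocast dispatch (#705, #810) to observe the ambient autocast state:

```python
import torch
import thunder

def foo(x):
    # no autocast context manager inside the compiled function
    y = x @ x
    return x + 2, y

jfoo = thunder.jit(foo)
x = torch.randn(16, 16, device="cuda")

# enter autocast around the call instead, so the enter/exit calls
# never appear in the traced region
with torch.autocast("cuda", torch.bfloat16):
    out = jfoo(x)
```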
Some automation revealed an even simpler reproducer:

```python
import torch
import thunder


class DynamoModule(torch.nn.Module):
    def forward(self):
        _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', torch.bfloat16, True, None)
        _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = _exit_autocast = None
        return ()


inputs = []

fqn = thunder.jit(DynamoModule())
fqn(*inputs)
```
🚀 Model / language coverage
I'm trying to get a fuller picture of what we need to support NeVA. As such I'm using:
(thanks Ivan for the great idea!)
And one of the issues that gets reported is e.g.
Pitch
This looks like it is going to be important for #343.
Alternatives / Potential work-arounds
Minimal Repro
cc @apaz-cli @crcrpar @tfogal