
enter/exit_autocast of torch.amp.autocast_mode #824

Open

tfogal opened this issue Jul 22, 2024 · 6 comments
Labels
amp · dynamo · nemo (Issues needed to support NVIDIA NeMo models) · program-coverage (Requests for model and program coverage) · triage review

Comments

@tfogal
Collaborator

tfogal commented Jul 22, 2024

🚀 Model / language coverage

I'm trying to get a fuller picture of what we need to support NeVA. As such, I'm using:

def thunder_backend(gm, args):
    gm.real_recompile()
    from thunder.examine import examine
    try:  # examine may raise an exception
        examine(gm, *args)
    except Exception as e:
        print(f"Hit problem with examine:\n{e}")
    # Don't actually use Thunder; just return the original graph
    return gm

...
#model.model = thunder.jit(model.model)
model.model = torch.compile(backend=thunder_backend)(model.model)

(thanks Ivan for the great idea!)

One of the issues it reports is, e.g.:

Found 2 distinct operations, of which 0 (0.0%) are supported
Please file an issue requesting the following operators here: https://github.com/Lightning-AI/lightning-thunder/issues/new
_enter_autocast of torch.amp.autocast_mode
_exit_autocast of torch.amp.autocast_mode

Pitch

This looks like it is going to be important for #343.

Alternatives / Potential work-arounds

Minimal Repro

cc @apaz-cli @crcrpar @tfogal

@tfogal tfogal added the program-coverage Requests for model and program coverage label Jul 22, 2024
@tfogal tfogal self-assigned this Jul 22, 2024
@tfogal
Collaborator Author

tfogal commented Jul 22, 2024

Assigning to me until I can fill out a better reproducer.

@tfogal tfogal added nemo Issues needed to support NVIDIA NeMo models. high priority labels Jul 22, 2024
@t-vi t-vi added the dynamo label Jul 23, 2024
@IvanYashchuk
Collaborator

Thunder doesn't support PyTorch's context managers like autocast, no_grad, enable_grad, etc. inside the compiled function. With the recent addition of autocast-specific dispatch at tracing time (#705, #810), supporting this might not take a lot of work; the challenge is to avoid reordering these enter and exit calls inappropriately.
Another way to approach this problem is to ensure Thunder never sees these calls by adding more graph breaks, as sketched below.
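A minimal sketch of the graph-break route (an illustration, not a confirmed fix): wrapping the autocast region in torch.compiler.disable forces Dynamo to run it eagerly and break the graph around the call, so the backend never receives _enter_autocast/_exit_autocast.

import torch

# Illustrative workaround sketch: torch.compiler.disable makes Dynamo run
# this function eagerly and break the graph around the call, so the
# _enter_autocast/_exit_autocast ops never reach the backend.
@torch.compiler.disable
def autocast_matmul(x):
    with torch.autocast("cuda", torch.bfloat16):
        return x @ x

def foo(x):
    y = autocast_matmul(x)  # graph break here; autocast stays in eager mode
    return x + 2, y

compiled = torch.compile(foo)  # captured graphs contain no autocast ops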

@kshitij12345
Collaborator

I think this could be the minimal repro:

import torch
import thunder

class ThunderJitBackend:
    def __init__(self, **compile_options) -> None:        
        self.thunder_jit_fns = []
        self.dynamo_graphs = []
        self.cnt = 0
        self.compile_options = compile_options

    def compile(self, gm, sample_args):
        self.dynamo_graphs.append(gm)
        gm.real_recompile()
        thunder_jit_fn = thunder.jit(gm, **self.compile_options)
        self.thunder_jit_fns.append(thunder_jit_fn)
        self.cnt += 1
        return thunder_jit_fn

dev = "cuda"

def foo(x):
    with torch.autocast(dev, torch.bfloat16):
        y = x @ x
    return x + 2, y

with torch.device(dev):
    model = foo
    x = torch.randn(16, 16)
    args = (x,)
    kwargs = {}

jit_backend = ThunderJitBackend()
cmodel = torch.compile(model, backend=jit_backend.compile)

o = cmodel(*args, **kwargs)
print(f"GRAPHS {jit_backend.cnt}")
print(jit_backend.dynamo_graphs[0])

for tfn in jit_backend.thunder_jit_fns:
    print(thunder.last_traces(tfn)[-1])

torch.testing.assert_close(o, model(*args, **kwargs))

Dynamo Graph

def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', torch.bfloat16, True, None)
    y = l_x_ @ l_x_
    _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast);  _enter_autocast = None
    add = l_x_ + 2;  l_x_ = None
    return (add, y)
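
For reference, these two ops are just the functional split of the torch.autocast context manager; an eager-mode paraphrase of the captured graph (my illustration) is:

import torch

# Eager equivalent of the Dynamo graph above (illustrative):
# _enter_autocast/_exit_autocast correspond to __enter__/__exit__ of a
# torch.autocast context manager instance.
def forward_eager(x):
    ctx = torch.autocast("cuda", torch.bfloat16, enabled=True, cache_enabled=None)
    ctx.__enter__()
    y = x @ x
    ctx.__exit__(None, None, None)
    return x + 2, y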

With dev="cuda", we see the following error:

  File "/home/kkalambarkar/git/pytorch/torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
  File "/home/kkalambarkar/git/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/kkalambarkar/git/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/core/module.py", line 61, in forward
    res = self._forward_fn(*args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/__init__.py", line 685, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/__init__.py", line 225, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/__init__.py", line 506, in get_computation_and_inputs
    jit_results: TraceResults = interpreter(
  File "/home/kkalambarkar/lightning-thunder/thunder/__init__.py", line 213, in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
  File "/home/kkalambarkar/lightning-thunder/thunder/core/jit_ext.py", line 1768, in thunder_general_jit
    result = jfn(*args, **kwargs)
  File "/home/kkalambarkar/lightning-thunder/thunder/core/interpreter.py", line 6760, in fn_
    raise InterpreterError(msg) from e
thunder.core.interpreter.InterpreterError: Encountered exception TypeError: unhashable type: 'instancemethod' while tracing GraphModule()

With dev="cpu", the program compiles but silently ignores the autocast region (it computes in single precision), so it fails at torch.testing.assert_close.
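
A quick way to observe the silent skip, reusing the repro variables above (my addition): under a working autocast the matmul output should come back in bfloat16.

# Assumes the repro above has run: foo returns (x + 2, y), where y is the
# matmul computed under autocast. torch.float32 here indicates the bug.
add_out, mm_out = cmodel(*args, **kwargs)
print(mm_out.dtype)  # expect torch.bfloat16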

@tfogal
Collaborator Author

tfogal commented Jul 23, 2024

Removing my assignment because Kshiteej is a hero w.r.t. finding minimal reproducers 😄. Thank you

@tfogal tfogal removed their assignment Jul 23, 2024
@tfogal
Collaborator Author

tfogal commented Jul 23, 2024

> Another way to approach this problem is to ensure Thunder never sees these calls by adding more graph breaks.

Yes, this is a great idea for the interim. I'll see if these are in actionable parts of the code (NeMo, or maybe megatron).

@tfogal
Collaborator Author

tfogal commented Aug 28, 2024

Some automation revealed an even simpler reproducer:

import torch
import thunder

class DynamoModule(torch.nn.Module):
    def forward(self):
        _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', torch.bfloat16, True, None)
        _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast);  _enter_autocast = _exit_autocast = None
        return ()

inputs = []
fqn = thunder.jit(DynamoModule())
fqn(*inputs)
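
For contrast (my addition, assuming a CUDA-enabled build), the same module runs without error in plain eager mode, which isolates the failure to thunder's interpreter:

# Eager sanity check: the functional enter/exit pair itself works fine
# outside of thunder.jit.
m = DynamoModule()
m()  # returns () with no error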
