Repro function saved from FX graph is segmented again when passed back to torch.compile #1521

Closed
kiya00 opened this issue Dec 5, 2024 · 6 comments · Fixed by #1540
Labels
benchmarking, debugging, thunderfx (for things that could be applicable to the dynamo+thunder frontend)

Comments

@kiya00
Collaborator

kiya00 commented Dec 5, 2024

Note: If you have a model or program that is not supported yet but should be, please use the program coverage template.

🐛 Bug

Because the saved reproducer function has already been processed by dynamo, it may contain dynamo-inserted operators such as torch.amp.autocast_mode._enter_autocast. If we pass this repro function to torch.compile again, it graph-breaks on these operators, which can make the torch.compile baseline noticeably slower.

To Reproduce

import torch
from thunder.dynamo import ThunderCompiler

def test_autocast():
    x = torch.rand(2, 2, device="cuda", requires_grad=True)

    backend = ThunderCompiler()

    def func(x):
        x = x + 2
        with torch.autocast("cuda"):
            y = torch.sin(x)
            return y

    cfunc = torch.compile(func, backend=backend)
    actual = cfunc(x)
    backend.save_reproducer_to_folder("repro")

Running the above script produces the repro function below; we modified it to print the graph-break information:

import torch
def test_graph0_thunder_0():
    class DynamoModule(torch.nn.Module):
      def forward(self, l_x_ : torch.Tensor):
          x = l_x_ + 2;  l_x_ = None
          _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', None, True, None)
          y = torch.sin(x);  x = None
          _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast);  _enter_autocast = _exit_autocast = None
          return y

    inputs = [
        torch.testing.make_tensor((2, 2), dtype=torch.float32,  device='cuda:0', requires_grad=True, low=0.32686540484428406, high=0.8776439428329468,).as_strided((2, 2), (2, 1)),

    ]

    mod = DynamoModule()
    exp = torch._dynamo.explain(mod)(*inputs)
    print(exp)
Graph Count: 2
Graph Break Count: 1
Op Count: 2
Break Reasons:
  Break Reason 1:
    Reason: torch.* op returned non-Tensor autocast call_function <function _enter_autocast at 0x76ad30d6d8a0>
    User Stack:
      <FrameSummary file /wayan/lightning-thunder/tmp.py, line 30 in forward>
  Break Reason 2:
    Reason: call_function args: AutocastModeVariable()
    User Stack:
      <FrameSummary file /wayan/lightning-thunder/tmp.py, line 32 in torch_dynamo_resume_in_forward_at_30>
Ops per Graph:
  Ops 1:
    <built-in function add>
  Ops 2:
    <built-in method sin of type object at 0x76ae6a584380>

Possible solution

Instead of using torch.compile(), we can pass the repro function directly to Inductor via torch._inductor.compile_fx. Because compile_fx consumes an already-traced FX graph and does not re-trace through dynamo, the autocast helpers no longer cause graph breaks:

import torch
def test_graph0_thunder_0():
    class DynamoModule(torch.nn.Module):
      def forward(self, l_x_ : torch.Tensor):
          x = l_x_ + 2;  l_x_ = None
          _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', None, True, None)
          y = torch.sin(x);  x = None
          _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast);  _enter_autocast = _exit_autocast = None
          return y

    inputs = [
        torch.testing.make_tensor((2, 2), dtype=torch.float32,  device='cuda:0', requires_grad=True, low=0.32686540484428406, high=0.8776439428329468,).as_strided((2, 2), (2, 1)),

    ]

    mod = DynamoModule()
    # compiled = torch.compile(mod)
    # compiled(*inputs)
    from torch._inductor.compile_fx import compile_fx
    from torch.fx import symbolic_trace
    fx_graph = symbolic_trace(mod)
    compiled = compile_fx(fx_graph, inputs)
    compiled(*inputs)

cc: @kshitij12345 @mruberry @IvanYashchuk

cc @crcrpar @apaz-cli

@mruberry
Collaborator

mruberry commented Dec 5, 2024

Great issue, @kiya00! What a pain that it graph breaks on the code it generated!

@kiya00
Collaborator Author

kiya00 commented Dec 6, 2024

Yeah, I didn't realize that before. When I tested the repro functions of the HF models with the torch.compile backend, they ran very slowly because of the graph breaks; here are some comparison results (and the repro script for each model is here)

@mruberry Do you think it's OK to use torch._inductor.compile_fx instead of torch.compile to get comparable results when benchmarking the repro function?
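
For reference, a minimal timing sketch of what such a comparison could look like (my own example; the bench helper, warm-up scheme, and iteration count are assumptions, not part of the repro):

import time
import torch

def bench(fn, inputs, iters=20):
    # warm up so compilation cost is excluded from the measurement
    for _ in range(3):
        fn(*inputs)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*inputs)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

# compare the two paths on the same DynamoModule and inputs from the repro:
# t_compile  = bench(torch.compile(mod), inputs)
# t_inductor = bench(compile_fx(symbolic_trace(mod), inputs), inputs)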

@mruberry
Collaborator

mruberry commented Dec 6, 2024


I think so, although @IvanYashchuk or @kshitij12345 may have some ideas. What differences do you think there would be? We can certainly use torch._inductor.compile_fx for now if you think it's best, and then review alternatives together later.

@kshitij12345
Collaborator

I was wondering if torch.compiler.allow_in_graph would work. It is used to treat a function as opaque to dynamo. If it works, using it would probably be better since it is a public API; otherwise torch._inductor.compile_fx sounds good to me.

Ref: https://pytorch.org/docs/stable/generated/torch.compiler.allow_in_graph.html
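
As a rough illustration (my own sketch, not from the thread), allow_in_graph is usually applied to a helper that is called from a compiled function, so dynamo records it as a single call instead of tracing into its body or graph-breaking on it:

import torch

def helper(x):
    # without allow_in_graph, dynamo would trace into this body
    return torch.sin(x) + 1

# mark the helper so dynamo keeps it as one node in the captured graph
torch.compiler.allow_in_graph(helper)

def fn(x):
    return helper(x) * 2

cfn = torch.compile(fn)
cfn(torch.rand(2, 2))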

@kiya00
Collaborator Author

kiya00 commented Dec 9, 2024

import torch

def forward(l_x_ : torch.Tensor):
    x = l_x_ + 2;  l_x_ = None
    _enter_autocast = torch.amp.autocast_mode._enter_autocast('cuda', None, True, None)
    y = torch.sin(x);  x = None
    _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast);  _enter_autocast = _exit_autocast = None
    return y

torch.compiler.allow_in_graph(forward)
# torch.compiler.allow_in_graph([torch.amp.autocast_mode._enter_autocast, torch.amp.autocast_mode._exit_autocast])
cfunc = torch.compile(forward)

x = torch.rand(2, 2, device="cuda", requires_grad=True)
cfunc(x)

The graph still breaks with allow_in_graph.

@kshitij12345
Collaborator

In that case, torch._inductor.compile_fx sounds good for now. Thanks for checking, @kiya00
