
Complete torch.compile executor #140


Merged · merged 39 commits into main from carmocca/complete-torch-compile on May 3, 2024

Conversation

carmocca
Contributor

@carmocca carmocca commented Apr 5, 2024

What does this PR do?

Adds an instance of TorchCompileExecutor named torch_compile_ex that registers all of pytorch_executor's and sdpa_ex's operators.

Renaming the existing executor torch_compile_executor to torch_compile_cat_ex breaks backward compatibility.
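
Downstream code that imports the old name will need updating. A minimal compatibility shim, purely illustrative and not part of this PR, could look like:

# Hypothetical shim: keeps code written against the old name working after the rename
from thunder.executors.torch_compile import torch_compile_cat_ex as torch_compile_executor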

Playground script with litgpt:

from litgpt import GPT
import thunder
import torch

from thunder import pytorch_executor
from thunder.executors.torch_compile import torch_compile_ex

with torch.device("cuda"):
    model = GPT.from_name("Llama-2-7b-hf", n_layer=1)

model = thunder.jit(model, executors=[torch_compile_ex])
x = torch.randint(model.max_seq_length, (2, 5), device="cuda")
y = model(x)

forward_trace = thunder.last_traces(model)[-1].python()
print(forward_trace)
assert "TorchCompile" in str(forward_trace)

Output:

@torch.no_grad()
@no_autocast
def augmented_forward_fn(idx, t_lm_head_weight, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_weight, t_transformer_h_0_norm_2_weight, t_transformer_ln_f_weight, t_transformer_wte_weight, tos1, t_sin):
  # idx: "cuda:0 i64[2, 5]"
  # t_lm_head_weight: "cuda:0 f32[32000, 4096]"
  # t_transformer_h_0_attn_attn_weight: "cuda:0 f32[12288, 4096]"
  # t_transformer_h_0_attn_proj_weight: "cuda:0 f32[4096, 4096]"
  # t_transformer_h_0_mlp_fc_1_weight: "cuda:0 f32[11008, 4096]"
  # t_transformer_h_0_mlp_fc_2_weight: "cuda:0 f32[11008, 4096]"
  # t_transformer_h_0_mlp_proj_weight: "cuda:0 f32[4096, 11008]"
  # t_transformer_h_0_norm_1_weight: "cuda:0 f32[4096]"
  # t_transformer_h_0_norm_2_weight: "cuda:0 f32[4096]"
  # t_transformer_ln_f_weight: "cuda:0 f32[4096]"
  # t_transformer_wte_weight: "cuda:0 f32[32000, 4096]"
  # tos1: "cuda:0 f32[4096, 128]"
  # t_sin: "cuda:0 f32[4096, 128]"
  [t10, t101, t106, t70, t81] = TorchCompile0(idx, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_weight, t_transformer_h_0_norm_2_weight, t_transformer_ln_f_weight, t_transformer_wte_weight, tos1)
  return {'output': t106, 'flat_args': [idx, t_lm_head_weight, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_weight, t_transformer_h_0_norm_2_weight, t_transformer_ln_f_weight, t_transformer_wte_weight, tos1, t_sin], 'flat_output': (t106,)}, ((idx, t10, t101, t70, t81, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_weight, t_transformer_h_0_norm_2_weight, t_transformer_ln_f_weight, t_transformer_wte_weight, tos1), (False, False, 0.29730177875068026, 0.29730177875068026, 4096.0, 4096.0, 4096.0, 32000, 2, -1))
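
The backward trace can be inspected the same way. A minimal sketch, assuming thunder.last_backward_traces as the backward counterpart of thunder.last_traces:

# Print the final, executable backward trace of the same jitted model
backward_trace = thunder.last_backward_traces(model)[-1].python()
print(backward_trace)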

Fixes https://github.com/Lightning-AI/lit-thunder-LEGACY/issues/2141

cc @Borda @apaz-cli

@carmocca carmocca self-assigned this Apr 5, 2024

@IvanYashchuk
Collaborator

> Rename the existing torch_compile executor to torch_compile_partial

I suggest renaming the existing executor to "concat_inductor". That is what it does: it uses Inductor to fuse the concatenation and its surrounding operations.
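
To illustrate the difference between the two executors, a minimal sketch using only names this PR confirms (pytorch_executor picks up whatever the cat executor does not claim):

import thunder
from thunder import pytorch_executor
from thunder.executors.torch_compile import torch_compile_cat_ex, torch_compile_ex

# torch_compile_cat_ex claims only the concatenation and its surrounding
# operations, leaving everything else to later executors in the list;
# torch_compile_ex claims every operator it supports.
model_cat = thunder.jit(model, executors=[torch_compile_cat_ex, pytorch_executor])
model_full = thunder.jit(model, executors=[torch_compile_ex])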

@carmocca carmocca force-pushed the carmocca/complete-torch-compile branch from f2e355b to 75bf7ea on April 10, 2024 12:09

@github-actions github-actions bot added the documentation label and removed the has conflicts label on Apr 23, 2024
@carmocca carmocca marked this pull request as ready for review April 29, 2024 17:05
@carmocca carmocca force-pushed the carmocca/complete-torch-compile branch from 58077e9 to ba75ad8 on April 29, 2024 17:45
@carmocca
Contributor Author

Are these known flakes on Windows?

FAILED thunder/tests/test_grad.py::test_vjp_correctness_getitem_torch_cpu_float64 - AssertionError: Scalars are not close!

Expected 75.60970520045251 but got 81.13586235571096.
Absolute difference: 5.526157155258446 (up to 1e-05 allowed)
Relative difference: 0.07308793415617461 (up to 1.3e-06 allowed)
FAILED thunder/tests/test_grad.py::test_phantom_grad_vs_torch_consistency_getitem_torch_cpu_bfloat16 - AssertionError: Tensor-likes are not close!

Mismatched elements: 112 / 140 (80.0%)
Greatest absolute difference: 4.0 at index (0, 0, 0) (up to 1e-05 allowed)
Greatest relative difference: 2.0 at index (0, 0, 0) (up to 0.016 allowed)
FAILED thunder/tests/test_grad.py::test_phantom_grad_vs_torch_consistency_getitem_torch_cpu_float64 - AssertionError: Tensor-likes are not close!

Mismatched elements: 112 / 140 (80.0%)
Greatest absolute difference: 7.0 at index (0, 0, 0) (up to 1e-07 allowed)
Greatest relative difference: 3.5 at index (0, 0, 0) (up to 1e-07 allowed)
FAILED thunder/tests/test_grad.py::test_phantom_grad_vs_torch_consistency_getitem_torch_cpu_float32 - AssertionError: Tensor-likes are not close!

Mismatched elements: 112 / 140 (80.0%)
Greatest absolute difference: 5.0 at index (0, 0, 0) (up to 1e-05 allowed)
Greatest relative difference: 2.5 at index (0, 0, 0) (up to 1.3e-06 allowed)

@apaz-cli
Contributor

@carmocca I usually assume that consistency tests are flakes, yeah. Re-run, and it should go away. If it doesn't, then it wasn't a flake :)

Collaborator

@t-vi t-vi left a comment


LGTM. Awesome work. Thank you @carmocca @IvanYashchuk @apaz-cli

@t-vi t-vi merged commit 7ac5684 into main May 3, 2024
36 of 39 checks passed
@t-vi t-vi deleted the carmocca/complete-torch-compile branch May 3, 2024 12:31
Comment on lines +269 to +274
if "inductor_cat" in self.compile:
from thunder.executors.torch_compile import torch_compile_cat_ex as torch_compile_ex

executors.insert(0, torch_compile_ex)
elif "inductor" in self.compile:
from thunder.executors.torch_compile import torch_compile_ex
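
Note that the order of the checks matters: "inductor" is a substring of "inductor_cat", so the more specific option must be tested first. As a rough sketch of what the selection amounts to (compile_option and default_executors are stand-ins for the benchmark's real attributes, and the elif branch presumably ends in the same insert):

import thunder

executors = list(default_executors)
if "inductor_cat" in compile_option:
    from thunder.executors.torch_compile import torch_compile_cat_ex as torch_compile_ex
    executors.insert(0, torch_compile_ex)  # cat-focused Inductor fusion gets top priority
elif "inductor" in compile_option:
    from thunder.executors.torch_compile import torch_compile_ex
    executors.insert(0, torch_compile_ex)  # full Inductor executor gets top priority
model = thunder.jit(model, executors=executors)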
Collaborator

I want to highlight that our nightly scripts and reporting depend on the current naming convention to monitor performance history. There was no real need to change these benchmark option names in this PR in a breaking way. In the future, I suggest we explore alternatives before finalizing a merge that alters existing behavior. Additionally, after merging, it's important to communicate such changes through various channels, not just GitHub, so that everyone is informed.

Labels: documentation, enhancement, executors, torch.compile
5 participants