
Investigate Memory and Performance difference using nvfuser vs torch.compile executor on Qwen2 #1552

Open · kshitij12345 opened this issue Dec 13, 2024 · 2 comments

Labels: high priority · memory use · nemo (Issues needed to support NVIDIA NeMo models) · performance · thunderfx (for things that could be applicable to the dynamo+thunder frontend)

Comments

@kshitij12345 (Collaborator) commented Dec 13, 2024

On internal image pjnl-20241213 and on H100 -

With ("sdpa", "torchcompile_cat", "nvfuser") -

# Memory - 49690.704896
# <torch.utils.benchmark.utils.common.Measurement object at 0x7ef10052ffe0>
# run_forward_backward()
#   187.69 ms
#   1 measurement, 10 runs , 1 thread

With ("sdpa", "torchcompile") -

# Memory - 40340.889088
# <torch.utils.benchmark.utils.common.Measurement object at 0x7fc319acd6d0>
# run_forward_backward()
#   153.62 ms
#   1 measurement, 10 runs , 1 thread

We should investigate what is causing the difference in memory usage and performance. Repro script:

import torch
import torch.utils.benchmark
from thunder.dynamo import ThunderCompiler
from transformers import AutoConfig, AutoModelForCausalLM
import thunder

model_id = "Qwen/Qwen2.5-7B-Instruct"

configuration = AutoConfig.from_pretrained(
    model_id,
    # num_hidden_layers=2,
)
# Shrink the hidden size to num_attention_heads (28), presumably to reduce model size;
# other config values (vocab size, intermediate size, etc.) are left unchanged.
configuration.hidden_size = configuration.num_attention_heads
with torch.device("cuda"):
    model = AutoModelForCausalLM.from_config(configuration).to(torch.bfloat16)

# Toggle between the two executor configurations being compared.
# executors = ("sdpa", "torchcompile_cat", "nvfuser")
executors = ("sdpa", "torchcompile")
backend = ThunderCompiler(executors=executors)
compiled_model = torch.compile(model, backend=backend)

input_ids = torch.randint(0, configuration.vocab_size, (1, 4096), device="cuda")

def run_forward_backward():
    compiled_output = compiled_model(input_ids=input_ids, labels=input_ids)
    compiled_output.loss.backward()


# Warm up (compilation happens on the first iteration).
for _ in range(5):
    run_forward_backward()

# Peak memory in MB.
print(torch.cuda.max_memory_allocated() / 1e6)

timer = torch.utils.benchmark.Timer(
    "run_forward_backward()",
    globals={"run_forward_backward": run_forward_backward},
)
measurement = timer.timeit(number=10)
print(measurement)

# With Nvfuser executor
# Memory - 49690.704896
# <torch.utils.benchmark.utils.common.Measurement object at 0x7ef10052ffe0>
# run_forward_backward()
#   187.69 ms
#   1 measurement, 10 runs , 1 thread

# With torch.compile executor
# Memory - 40340.889088
# <torch.utils.benchmark.utils.common.Measurement object at 0x7fc319acd6d0>
# run_forward_backward()
#   153.62 ms
#   1 measurement, 10 runs , 1 thread

cc @apaz-cli @tfogal

@tfogal added the thunderfx, high priority, and nemo labels on Dec 13, 2024
@IvanYashchuk (Collaborator) commented:
There could be a conflict between the "torchcompile_cat" and "nvfuser" executors, creating more fusions and materializing more intermediates than necessary. Inspecting the execution traces could reveal whether that is happening. If so, a potential solution could be to expand the scope of the "torchcompile_cat" executor so that it pulls more operations into its fusion regions.
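
For reference, a minimal sketch of how the execution traces could be inspected for the repro above. It assumes the ThunderCompiler backend exposes its compiled submodules via backend.subgraph_infos and thunder_compiled_fns; the attribute names may differ between lightning-thunder versions, so treat this as a starting point rather than a verified recipe.

# Sketch: dump Thunder's final forward/backward execution traces for the repro above.
# Assumption: `backend` is the ThunderCompiler instance from the repro script and it
# keeps its compiled submodules in `backend.subgraph_infos[*].thunder_compiled_fns`
# (names may vary by lightning-thunder version).
import thunder

run_forward_backward()  # ensure the model has been compiled and executed at least once

for subgraph_info in backend.subgraph_infos:
    for thunder_fn in subgraph_info.thunder_compiled_fns:
        print(thunder.last_traces(thunder_fn)[-1])           # final forward execution trace
        print(thunder.last_backward_traces(thunder_fn)[-1])  # final backward execution trace

Comparing the backward traces produced with ("sdpa", "torchcompile_cat", "nvfuser") against those produced with ("sdpa", "torchcompile") should show which extra intermediates are being materialized.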

@riccardofelluga (Collaborator) commented Dec 18, 2024

From initial triage, it looks like the nvFuser fusion pass in Thunder is picking up a number of transpose and reshape ops from around the trace and fusing them at the start of the backward pass, together with the actual computation needed for torch_nll_loss_backward_impl (as can be seen in nvFusion0 below):

[... omitted ...]
value_states_2, value_states_5, = C0
clear_mutable_collection(C0)
del C0
[t1148, t1190, t1202, t1275, t1284, t1286, t1307, t1309, t1501, t1578, t1590, t1663, t1672, t1674, t1695, t1697, t1897] = nvFusion0(hidden_states_24, mul_19, hidden_states_19, attn_output_6, value_states_5, t651, t627, t620, hidden_states_13, mul_10, hidden_states_8, attn_output_2, value_states_2, t451, t429, t425, hidden_states_2)
  # t1148 = prims.reshape(hidden_states_24, (4096, 28))  # t1148: "cuda:0 bf16[4096, 28]"
  # t1190 = prims.reshape(mul_19, (4096, 18944))  # t1190: "cuda:0 bf16[4096, 18944]"
  # t1202 = prims.reshape(hidden_states_19, (4096, 28))  # t1202: "cuda:0 bf16[4096, 28]"
  # t1275 = prims.reshape(attn_output_6, (4096, 28))  # t1275: "cuda:0 bf16[4096, 28]"
  # t1284 = prims.transpose(value_states_5, (0, 1, 3, 2))  # t1284: "cuda:0 bf16[1, 28, 1, 4096]"
  # t1286 = prims.transpose(t651, (0, 1, 3, 2))  # t1286: "cuda:0 bf16[1, 28, 4096, 4096]"
  # t1307 = prims.transpose(t627, (0, 1, 3, 2))  # t1307: "cuda:0 bf16[1, 28, 4096, 2]"
  # t1309 = prims.transpose(t620, (0, 1, 3, 2))  # t1309: "cuda:0 bf16[1, 28, 2, 4096]"
  # t1501 = prims.reshape(hidden_states_13, (4096, 28))  # t1501: "cuda:0 bf16[4096, 28]"
  # t1578 = prims.reshape(mul_10, (4096, 18944))  # t1578: "cuda:0 bf16[4096, 18944]"
  # t1590 = prims.reshape(hidden_states_8, (4096, 28))  # t1590: "cuda:0 bf16[4096, 28]"
  # t1663 = prims.reshape(attn_output_2, (4096, 28))  # t1663: "cuda:0 bf16[4096, 28]"
  # t1672 = prims.transpose(value_states_2, (0, 1, 3, 2))  # t1672: "cuda:0 bf16[1, 28, 1, 4096]"
  # t1674 = prims.transpose(t451, (0, 1, 3, 2))  # t1674: "cuda:0 bf16[1, 28, 4096, 4096]"
  # t1695 = prims.transpose(t429, (0, 1, 3, 2))  # t1695: "cuda:0 bf16[1, 28, 4096, 2]"
  # t1697 = prims.transpose(t425, (0, 1, 3, 2))  # t1697: "cuda:0 bf16[1, 28, 2, 4096]"
  # t1897 = prims.reshape(hidden_states_2, (4096, 28))  # t1897: "cuda:0 bf16[4096, 28]"
del hidden_states_24, mul_19, hidden_states_19, attn_output_6, value_states_5, t627, t620, hidden_states_13, mul_10, hidden_states_8, attn_output_2, value_states_2, t429, t425, hidden_states_2
t1129 = torch_nll_loss_backward_impl(t365, t744, shift_labels_1, None, 'mean', -100, t752)  # t1129: "cuda:0 f32[4095, 152064]"
del t365, shift_labels_1, t752
[t1143, t1147] = nvFusion1(t744, t1129, t366)
[... omitted ...]

This is problematic because the start of the backward pass is already where peak memory usage occurs, even without nvFusion0.
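
To confirm where the peak occurs, one can reset the peak-memory counter between the forward and backward passes of the repro; this only uses standard torch.cuda APIs:

# Sketch: attribute peak memory to the forward vs. backward pass of the repro above.
# Note that the backward peak still includes the activations saved during the forward
# pass (they remain alive), so this only shows when the peak happens, not its cause.
torch.cuda.reset_peak_memory_stats()
compiled_output = compiled_model(input_ids=input_ids, labels=input_ids)
fwd_peak = torch.cuda.max_memory_allocated()

torch.cuda.reset_peak_memory_stats()
compiled_output.loss.backward()
bwd_peak = torch.cuda.max_memory_allocated()

print(f"forward peak:  {fwd_peak / 1e6:.1f} MB")
print(f"backward peak: {bwd_peak / 1e6:.1f} MB")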

From what I can see, the work here is twofold: we can configure nvFuser not to pick up transpose/reshape ops, and then work on the scheduling of computation so that tensors are produced closer to their consumers (this idea is the main thread in at least #1337, #1560 and #1562). A sketch of the first part is below.
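
A possible first experiment, assuming the nvFuser executor still exposes the nv_enable_bookend compile option (which, as I understand it, pushes shape-only ops such as transpose/reshape out to the fusion boundaries), and assuming ThunderCompiler forwards extra keyword arguments to thunder.jit. The option name, default, and exact semantics may have changed, so this is a sketch rather than a verified fix:

# Sketch (assumption-heavy): ask the nvFuser executor to keep shape-only ops out of
# its fusion regions via the `nv_enable_bookend` compile option. If this option is no
# longer available or has different semantics in the current lightning-thunder, the
# relevant knob lives in the nvFuser executor and may be named differently.
from thunder.dynamo import ThunderCompiler

backend = ThunderCompiler(
    executors=("sdpa", "torchcompile_cat", "nvfuser"),
    nv_enable_bookend=True,  # assumption: True excludes transpose/reshape at fusion edges
)
compiled_model = torch.compile(model, backend=backend)

Comparing peak memory and the shape of nvFusion0 with and without this option should tell us how much of the regression comes from the picked-up transpose/reshape ops.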
