[benchmark_litgpt] Add an option to use FSDP2 when `torch.compile` is used (#940)

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar authored Aug 19, 2024
1 parent dcf4782 commit d425fe4
Showing 1 changed file with 110 additions and 39 deletions.
149 changes: 110 additions & 39 deletions thunder/benchmarks/benchmark_litgpt.py
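
For orientation, the new `fsdp2` path added here can be exercised programmatically roughly as follows. This is a hedged sketch, not the canonical CLI: `benchmark_main(**kwargs)` does appear in the diff, but the keyword names below are inferred from attributes visible in the changed code (`distributed_mode`, `shard_mode`, `compile`, `dump_memory_snapshot`, `dump_thunder_traces`), and the script is normally launched under `torchrun` so that `WORLD_SIZE`, `LOCAL_RANK`, and `RANK` are populated.

# Hedged sketch: keyword names are inferred from attributes in this diff and may
# not match the full constructor/CLI exactly. Run under torchrun so that
# WORLD_SIZE, LOCAL_RANK, and RANK are set.
from thunder.benchmarks.benchmark_litgpt import benchmark_main

benchmark_main(
    distributed_mode="fsdp2",   # new mode added by this commit
    shard_mode="zero3",         # maps to reshard_after_forward=True in the fsdp2 branch
    compile="inductor",         # fsdp2 requires eager/inductor, or dynamo with Thunder as the backend
    dump_memory_snapshot=True,  # writes <model>_<compile>_<mode>_<shard>.pickle after training
    dump_thunder_traces=False,
)
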
@@ -2,15 +2,13 @@
import os
import time
import warnings
from typing import Any
from contextlib import nullcontext

import torch
import functools
from torch.utils.data import DataLoader, IterableDataset
import torch.distributed as torch_dist
from torch.distributed.device_mesh import init_device_mesh
import warnings

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
@@ -35,6 +33,7 @@
transformer_engine_available = False


FSDP_MODES: set[str] = {"fsdp", "fsdp2"}
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = int(os.environ.get("RANK", 0))
@@ -77,30 +76,6 @@ def configure_optimizers(model, weight_decay, learning_rate, betas, device_type)
return optimizer


# NOTE(crcrpar): Calling this method seems to bloat the memory consumption to some extent.
# e.g. ref: https://github.com/Lightning-AI/lightning-thunder/issues/439
def _run_fwd_bwd_one_microbatch(
model: torch.nn.Module,
input_ids: torch.Tensor,
targets: torch.Tensor,
gradient_accumulation_steps: int,
device: torch.device,
use_te_fp8_autocast: bool,
) -> torch.Tensor:
input_ids = input_ids.to(device)
targets = targets.to(device)
if use_te_fp8_autocast:
with te.fp8_autocast(enabled=True):
logits = model(input_ids)
else:
logits = model(input_ids)
logits = logits.reshape(-1, logits.size(-1))
targets = targets.reshape(-1)
loss = torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) / gradient_accumulation_steps
loss.backward()
return loss


class Benchmark_litGPT:
def __init__(
self,
@@ -124,6 +99,8 @@ def __init__(
low_precision_mode: str = "none",
max_iters: int = 45,
warmup_iters: int = 25,
dump_thunder_traces: bool = False,
dump_memory_snapshot: bool = False,
):
seed = 1337
torch.manual_seed(seed)
@@ -154,6 +131,9 @@ def __init__(
self.micro_batch_size = micro_batch_size
self.low_precision_mode = low_precision_mode
self.use_te_fp8_autocast = is_transformer_engine(low_precision_mode) and "thunder" not in compile
self.is_thunder_as_torchcompile_backend = False
self.dump_thunder_traces = dump_thunder_traces
self.dump_memory_snapshot = dump_memory_snapshot

# Clarify benchmark assumptions
if self.sharding_size is not None:
@@ -170,26 +150,24 @@ def __init__(
world_size % self.sharding_size == 0
), f"World size {world_size} is not divisible by the sharding size {self.sharding_size}"

if self.bucketing_mode != "none" and self.distributed_mode != "fsdp":
print(
f"[WARNING] --bucketing_mode {self.bucketing_mode} will be ignored as \
it is only used for FSDP style parallelism but running {self.distributed_mode}"
if self.bucketing_mode != "none" and self.distributed_mode not in FSDP_MODES:
warnings.warn(
f"--bucketing_mode {self.bucketing_mode} will be ignored as "
f" it is only used for FSDP style parallelism but running {self.distributed_mode}"
)

assert not (
"thunder" in self.compile and self.bucketing_mode == "size"
), "'size' bucketing mode is not supported for Thunder. Please use 'none' or 'block'."

if self.fsdp_bucket_params is not None:
if self.distributed_mode != "fsdp":
print(
f"[WARNING] Found --fsdp_bucket_params but Distributed mode is {self.distributed_mode}. Will be ignored"
if self.distributed_mode not in FSDP_MODES:
warnings.warn(
f"Found --fsdp_bucket_params but Distributed mode is {self.distributed_mode}. Will be ignored"
)

if self.bucketing_mode != "size":
print(
f"[WARNING] Bucketing mode is set to {self.bucketing_mode}. --fsdp_bucket_params will be ignored."
)
warnings.warn(f"Bucketing mode is set to {self.bucketing_mode}. --fsdp_bucket_params will be ignored.")

if is_transformer_engine(low_precision_mode):
if not transformer_engine_available:
@@ -271,7 +249,7 @@ def __init__(
}

def init_model(self):
init_device = torch.device("meta") if self.distributed_mode == "fsdp" else self.device
init_device = torch.device("meta") if self.distributed_mode in FSDP_MODES else self.device
with init_device:
model = GPT(self.config)
model.to(dtype=torch.bfloat16)
@@ -307,13 +285,55 @@ def setup_distributed(self, model):
sharding_strategy=sharding_strategy,
bucketing_strategy=bucketing_strategy,
)
else:
if self.distributed_mode == "fsdp2":
raise ValueError(
f"To use `fsdp2`, use thunder as torch.compile backend by including dynamo in `--compile` option or set `--compile` to either eager or inductor"
)
else:
if self.distributed_mode == "ddp":
model = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[local_rank],
bucket_cap_mb=self.ddp_bucket_size,
)
elif self.distributed_mode == "fsdp2":
# reference: https://github.com/pytorch/torchtitan/blob/6e7a183/docs/fsdp.md
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

if self.bucketing_mode != "none":
warnings.warn(f"fsdp2 ignores {self.bucketing_mode=}")

torch.cuda.set_device(local_rank)
mesh = None
if self.sharding_size is not None:
mesh = init_device_mesh("cuda", (int(world_size / self.sharding_size), self.sharding_size))
else:
mesh = init_device_mesh("cuda", (world_size,))

# zero3 reshards parameters after each forward pass (ZeRO-3 style); otherwise they stay gathered until backward.
reshard_after_forward: bool = self.shard_mode == "zero3"

for transformer_block in model.modules():
if isinstance(transformer_block, Block):
fully_shard(
transformer_block,
mesh=mesh,
reshard_after_forward=reshard_after_forward,
mp_policy=MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
),
),
)

fully_shard(
model,
mesh=mesh,
reshard_after_forward=reshard_after_forward,
mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
)
model.to_empty(device=self.device)
model.apply(model._init_weights)

elif self.distributed_mode == "fsdp":
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy, size_based_auto_wrap_policy
@@ -396,6 +416,24 @@ def setup_compile(self, model):
executors.insert(0, transformer_engine_ex)

if "dynamo" in self.compile:
if self.distributed_mode == "fsdp2":
print("Resetting cache size for when fsdp2 and using thunder as backend torch.compile")
import torch._dynamo.config as dynamo_config

dynamo_config.cache_size_limit = 64
if "transformerengine" in self.compile:
# [rank0]: File "/opt/pytorch/lightning-thunder/thunder/executors/transformer_engineex.py", line 410, in _te_functional_linear_backward_impl
# [rank0]: grads = _Linear.backward(ctx, g)
# [rank0]: File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/linear.py", line 449, in backward
# [rank0]: weight_fp8.transpose_2d(),
# [rank0]: File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/float8_tensor.py", line 625, in transpose_2d
# [rank0]: if self._transpose is None:
# [rank0]: File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/float8_tensor.py", line 39, in get_func
# [rank0]: return self._fp8_attrs[name]
# [rank0]: AttributeError: 'Float8Tensor' object has no attribute '_fp8_attrs'
raise ValueError(
"TransformerEngine executor cannot be used as an executor of Thunder when Thunder is used as torch.compile backend"
)
backend = ThunderCompiler(executors=executors)
# Because Lightning Fabric is imported in this script it monkey patches the torch.compile function
# https://github.com/Lightning-AI/pytorch-lightning/blob/828fd998961f6a60f92c35254bb94d6e049ad069/src/lightning/fabric/wrappers.py#L421
@@ -466,6 +504,8 @@ def train(self):
iter_t0 = time.perf_counter()
if i == self.warmup_iters: # warmup
t0 = iter_t0
if self.dump_memory_snapshot and global_rank in (0, None):
torch.cuda.memory._record_memory_history()

if self.nsys_enabled and i == self.profiler_start and global_rank in [0, None]:
print("=====Start NSYS Profiling======")
@@ -549,7 +589,7 @@ def add_model_info_to_metrics(self):
self.perf_metrics["Micro BS"] = self.micro_batch_size
self.perf_metrics["Global BS"] = self.global_batch_size
self.perf_metrics["GA"] = self.gradient_accumulation_steps
if self.distributed_mode in ["fsdp"]:
if self.distributed_mode in FSDP_MODES:
self.perf_metrics["Distributed Mode"] = (
str(self.distributed_mode)
+ "_"
@@ -602,7 +642,7 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
f"Number of Layers: {benchmark.config.n_layer}\nNumber of parameters: {sum(p.numel() for p in benchmark.model.parameters() if p.requires_grad) / 1e9:.02f}B"
)
print(f"Distributed Mode: {benchmark.distributed_mode}")
if benchmark.distributed_mode == "fsdp":
if benchmark.distributed_mode in FSDP_MODES:
print(f"Sharding Mode: {benchmark.shard_mode}\nBucketing: {benchmark.bucketing_mode}")
if benchmark.sharding_size is not None:
print(
Expand All @@ -620,6 +660,37 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
print(f"Tokens/s/GPU: {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f}")
print(f"TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec'] / 1e12:.02f}")

if benchmark.dump_memory_snapshot:
file_name = f"{benchmark.model_name}_{benchmark.compile}_{benchmark.distributed_mode}"
if benchmark.distributed_mode.startswith("fsdp"):
file_name = f"{file_name}_{benchmark.shard_mode}"
if benchmark.distributed_mode == "fsdp":
file_name += f"_{benchmark.bucketing_mode}"
if benchmark.distributed_mode == "ddp":
file_name += f"_{benchmark.ddp_bucket_size}"
file_name = f"{file_name}.pickle"
print(f"Dump memory snapshot at {file_name}")
torch.cuda.memory._dump_snapshot(file_name)
torch.cuda.memory._record_memory_history(enabled=None)

if benchmark.dump_thunder_traces:
if benchmark.is_thunder_as_torchcompile_backend:
print(f"{len(benchmark.thunder_as_torch_compile_backend.gm_to_thunder)} ThunderModule's are created")
fwd_traces, bwd_traces = [], []
for jitted in benchmark.thunder_as_torch_compile_backend.gm_to_thunder.values():
fwd_traces.append(thunder.last_traces(jitted))
bwd_traces.append(thunder.last_backward_traces(jitted))
else:
fwd_traces = [thunder.last_traces(benchmark.model)]
bwd_traces = [thunder.last_backward_traces(benchmark.model)]

for i, f_traces in enumerate(fwd_traces, start=1):
print(f"##########\n#{i}-th ThunderModule\n##########")
print(f_traces[-1])
for i, b_traces in enumerate(bwd_traces, start=1):
print(f"##########\n#{i}-th ThunderModule\n##########")
print(b_traces[-1])

if global_rank in [0, None]:
if return_metrics_as_json:
benchmark.add_model_info_to_metrics()
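
As a condensed restatement of the two pieces this commit connects — FSDP2's per-`Block` `fully_shard` and Thunder used as a `torch.compile` backend — a minimal sketch might look like the following. It is not the benchmark itself: `block_cls`, the bare `ThunderCompiler()` call, its `thunder.dynamo` import path, and the final `torch.compile(model, backend=...)` call are assumptions not confirmed by this excerpt; the sharding, meta-device materialization, and cache-size steps follow the diff above.

import torch
import torch._dynamo.config as dynamo_config
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from thunder.dynamo import ThunderCompiler  # assumed import path

def shard_and_compile(model: torch.nn.Module, block_cls: type, world_size: int, device: torch.device):
    # Build a 1-D device mesh over all ranks (no hybrid sharding).
    mesh = init_device_mesh("cuda", (world_size,))
    mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)

    # Shard each transformer block, then the root module; reshard_after_forward=True is the zero3 behavior.
    for module in model.modules():
        if isinstance(module, block_cls):
            fully_shard(module, mesh=mesh, reshard_after_forward=True, mp_policy=mp)
    fully_shard(model, mesh=mesh, reshard_after_forward=True, mp_policy=mp)

    # Parameters were created on the meta device; materialize them locally and re-run weight init
    # (assumes the model defines _init_weights, as litgpt's GPT does).
    model.to_empty(device=device)
    model.apply(model._init_weights)

    # Let Dynamo keep more compiled graphs around before frames fall back to eager.
    dynamo_config.cache_size_limit = 64
    return torch.compile(model, backend=ThunderCompiler())
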

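If `--dump_memory_snapshot` is enabled, the pickle written by `torch.cuda.memory._dump_snapshot` can be inspected with PyTorch's memory visualizer. A hedged sketch follows: `torch.cuda._memory_viz` is a private helper whose API may change, and the file name is hypothetical, following the naming pattern in `benchmark_main`; alternatively, the pickle can be loaded interactively at https://pytorch.org/memory_viz.

import pickle
from torch.cuda import _memory_viz  # private helper; API may change between releases

snapshot_file = "Llama-2-7b-hf_inductor_fsdp2_zero3.pickle"  # hypothetical name
with open(snapshot_file, "rb") as f:
    snapshot = pickle.load(f)

# trace_plot renders the allocation timeline as a self-contained HTML page.
html = _memory_viz.trace_plot(snapshot)
with open("memory_trace.html", "w") as f:
    f.write(html)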