diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 00184b2084..2f4ead75b1 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -2,7 +2,6 @@
 import os
 import time
 import warnings
-from typing import Any
 from contextlib import nullcontext
 
 import torch
@@ -10,7 +9,6 @@
 from torch.utils.data import DataLoader, IterableDataset
 import torch.distributed as torch_dist
 from torch.distributed.device_mesh import init_device_mesh
-import warnings
 
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     apply_activation_checkpointing,
@@ -35,6 +33,7 @@
     transformer_engine_available = False
 
 
+FSDP_MODES: set[str] = {"fsdp", "fsdp2"}
 world_size = int(os.environ.get("WORLD_SIZE", 1))
 local_rank = int(os.environ.get("LOCAL_RANK", 0))
 global_rank = int(os.environ.get("RANK", 0))
@@ -77,30 +76,6 @@ def configure_optimizers(model, weight_decay, learning_rate, betas, device_type)
     return optimizer
 
 
-# NOTE(crcrpar): Calling this method seems to bloat the memory consumption to some extent.
-# e.g. ref: https://github.com/Lightning-AI/lightning-thunder/issues/439
-def _run_fwd_bwd_one_microbatch(
-    model: torch.nn.Module,
-    input_ids: torch.Tensor,
-    targets: torch.Tensor,
-    gradient_accumulation_steps: int,
-    device: torch.device,
-    use_te_fp8_autocast: bool,
-) -> torch.Tensor:
-    input_ids = input_ids.to(device)
-    targets = targets.to(device)
-    if use_te_fp8_autocast:
-        with te.fp8_autocast(enabled=True):
-            logits = model(input_ids)
-    else:
-        logits = model(input_ids)
-    logits = logits.reshape(-1, logits.size(-1))
-    targets = targets.reshape(-1)
-    loss = torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) / gradient_accumulation_steps
-    loss.backward()
-    return loss
-
-
 class Benchmark_litGPT:
     def __init__(
         self,
@@ -124,6 +99,8 @@ def __init__(
         low_precision_mode: str = "none",
         max_iters: int = 45,
         warmup_iters: int = 25,
+        dump_thunder_traces: bool = False,
+        dump_memory_snapshot: bool = False,
     ):
         seed = 1337
         torch.manual_seed(seed)
@@ -154,6 +131,9 @@ def __init__(
         self.micro_batch_size = micro_batch_size
         self.low_precision_mode = low_precision_mode
         self.use_te_fp8_autocast = is_transformer_engine(low_precision_mode) and "thunder" not in compile
+        self.is_thunder_as_torchcompile_backend = False
+        self.dump_thunder_traces = dump_thunder_traces
+        self.dump_memory_snapshot = dump_memory_snapshot
 
         # Clarify benchmark assumptions
         if self.sharding_size is not None:
@@ -170,10 +150,10 @@ def __init__(
                 world_size % self.sharding_size == 0
             ), f"World size {world_size} is not divisible by the sharding size {self.sharding_size}"
 
-        if self.bucketing_mode != "none" and self.distributed_mode != "fsdp":
-            print(
-                f"[WARNING] --bucketing_mode {self.bucketing_mode} will be ignored as \
-                it is only used for FSDP style parallelism but running {self.distributed_mode}"
+        if self.bucketing_mode != "none" and self.distributed_mode not in FSDP_MODES:
+            warnings.warn(
+                f"--bucketing_mode {self.bucketing_mode} will be ignored as "
+                f"it is only used for FSDP style parallelism but running {self.distributed_mode}"
             )
 
         assert not (
@@ -181,15 +161,13 @@ def __init__(
         ), "'size' bucketing mode is not supported for Thunder. Please use 'none' or 'block'."
 
         if self.fsdp_bucket_params is not None:
-            if self.distributed_mode != "fsdp":
-                print(
-                    f"[WARNING] Found --fsdp_bucket_params but Distributed mode is {self.distributed_mode}. Will be ignored"
Will be ignored" + if self.distributed_mode not in FSDP_MODES: + warnings.warn( + f"Found --fsdp_bucket_params but Distributed mode is {self.distributed_mode}. Will be ignored" ) if self.bucketing_mode != "size": - print( - f"[WARNING] Bucketing mode is set to {self.bucketing_mode}. --fsdp_bucket_params will be ignored." - ) + warnings.warn(f"Bucketing mode is set to {self.bucketing_mode}. --fsdp_bucket_params will be ignored.") if is_transformer_engine(low_precision_mode): if not transformer_engine_available: @@ -271,7 +249,7 @@ def __init__( } def init_model(self): - init_device = torch.device("meta") if self.distributed_mode == "fsdp" else self.device + init_device = torch.device("meta") if self.distributed_mode in FSDP_MODES else self.device with init_device: model = GPT(self.config) model.to(dtype=torch.bfloat16) @@ -307,6 +285,11 @@ def setup_distributed(self, model): sharding_strategy=sharding_strategy, bucketing_strategy=bucketing_strategy, ) + else: + if self.distributed_mode == "fsdp2": + raise ValueError( + f"To use `fsdp2`, use thunder as torch.compile backend by including dynamo in `--compile` option or set `--compile` to either eager or inductor" + ) else: if self.distributed_mode == "ddp": model = torch.nn.parallel.DistributedDataParallel( @@ -314,6 +297,43 @@ def setup_distributed(self, model): device_ids=[local_rank], bucket_cap_mb=self.ddp_bucket_size, ) + elif self.distributed_mode == "fsdp2": + # reference: https://github.com/pytorch/torchtitan/blob/6e7a183/docs/fsdp.md + from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy + + if self.bucketing_mode != "none": + warnings.warn(f"fsdp2 ignores {self.bucketing_mode=}") + + torch.cuda.set_device(local_rank) + mesh = None + if self.sharding_size is not None: + mesh = init_device_mesh("cuda", (int(world_size / self.sharding_size), self.sharding_size)) + else: + mesh = init_device_mesh("cuda", (world_size,)) + + reshard_after_forward: bool = self.shard_mode == "zero3" + + for transformer_block in model.modules(): + if isinstance(transformer_block, Block): + fully_shard( + transformer_block, + mesh=mesh, + reshard_after_forward=reshard_after_forward, + mp_policy=MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + ), + ) + + fully_shard( + model, + mesh=mesh, + reshard_after_forward=reshard_after_forward, + mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16), + ) + model.to_empty(device=self.device) + model.apply(model._init_weights) + elif self.distributed_mode == "fsdp": from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy, size_based_auto_wrap_policy @@ -396,6 +416,24 @@ def setup_compile(self, model): executors.insert(0, transformer_engine_ex) if "dynamo" in self.compile: + if self.distributed_mode == "fsdp2": + print("Resetting cache size for when fsdp2 and using thunder as backend torch.compile") + import torch._dynamo.config as dynamo_config + + dynamo_config.cache_size_limit = 64 + if "transformerengine" in self.compile: + # [rank0]: File "/opt/pytorch/lightning-thunder/thunder/executors/transformer_engineex.py", line 410, in _te_functional_linear_backward_impl + # [rank0]: grads = _Linear.backward(ctx, g) + # [rank0]: File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/linear.py", line 449, in backward + # [rank0]: weight_fp8.transpose_2d(), + # [rank0]: File 
"/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/float8_tensor.py", line 625, in transpose_2d + # [rank0]: if self._transpose is None: + # [rank0]: File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/float8_tensor.py", line 39, in get_func + # [rank0]: return self._fp8_attrs[name] + # [rank0]: AttributeError: 'Float8Tensor' object has no attribute '_fp8_attrs' + raise ValueError( + "TransformerEngine executor cannot be used as an executor of Thunder when Thunder is used as torch.compile backend" + ) backend = ThunderCompiler(executors=executors) # Because Lightning Fabric is imported in this script it monkey patches the torch.compile function # https://github.com/Lightning-AI/pytorch-lightning/blob/828fd998961f6a60f92c35254bb94d6e049ad069/src/lightning/fabric/wrappers.py#L421 @@ -466,6 +504,8 @@ def train(self): iter_t0 = time.perf_counter() if i == self.warmup_iters: # warmup t0 = iter_t0 + if self.dump_memory_snapshot and global_rank in (0, None): + torch.cuda.memory._record_memory_history() if self.nsys_enabled and i == self.profiler_start and global_rank in [0, None]: print("=====Start NSYS Profiling======") @@ -549,7 +589,7 @@ def add_model_info_to_metrics(self): self.perf_metrics["Micro BS"] = self.micro_batch_size self.perf_metrics["Global BS"] = self.global_batch_size self.perf_metrics["GA"] = self.gradient_accumulation_steps - if self.distributed_mode in ["fsdp"]: + if self.distributed_mode in FSDP_MODES: self.perf_metrics["Distributed Mode"] = ( str(self.distributed_mode) + "_" @@ -602,7 +642,7 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None f"Number of Layers: {benchmark.config.n_layer}\nNumber of parameters: {sum(p.numel() for p in benchmark.model.parameters() if p.requires_grad) / 1e9:.02f}B" ) print(f"Distributed Mode: {benchmark.distributed_mode}") - if benchmark.distributed_mode == "fsdp": + if benchmark.distributed_mode in FSDP_MODES: print(f"Sharding Mode: {benchmark.shard_mode}\nBucketing: {benchmark.bucketing_mode}") if benchmark.sharding_size is not None: print( @@ -620,6 +660,37 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None print(f"Tokens/s/GPU: {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f}") print(f"TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec'] / 1e12:.02f}") + if benchmark.dump_memory_snapshot: + file_name = f"{benchmark.model_name}_{benchmark.compile}_{benchmark.distributed_mode}" + if benchmark.distributed_mode.startswith("fsdp"): + file_name = f"{file_name}_{benchmark.shard_mode}" + if benchmark.distributed_mode == "fsdp": + file_name += f"_{benchmark.bucketing_mode}" + if benchmark.distributed_mode == "ddp": + file_name += f"_{benchmark.ddp_bucket_size}" + file_name = f"{file_name}.pickle" + print(f"Dump memory snapshot at {file_name}") + torch.cuda.memory._dump_snapshot(file_name) + torch.cuda.memory._record_memory_history(enabled=None) + + if benchmark.dump_thunder_traces: + if benchmark.is_thunder_as_torchcompile_backend: + print(f"{len(benchmark.thunder_as_torch_compile_backend.gm_to_thunder)} ThunderModule's are created") + fwd_traces, bwd_traces = [], [] + for jitted in benchmark.thunder_as_torch_compile_backend.gm_to_thunder.values(): + fwd_traces.append(thunder.last_traces(jitted)) + bwd_traces.append(thunder.last_backward_traces(jitted)) + else: + fwd_traces = [thunder.last_traces(benchmark.model)] + bwd_traces = [thunder.last_backward_traces(benchmark.model)] + + for i, f_traces in enumerate(fwd_traces, start=1): + 
print(f"##########\n#{i}-th ThunderModule\n##########") + print(f_traces[-1]) + for i, b_traces in enumerate(bwd_traces, start=1): + print(f"##########\n#{i}-th ThunderModule\n##########") + print(b_traces[-1]) + if global_rank in [0, None]: if return_metrics_as_json: benchmark.add_model_info_to_metrics()