diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 9d590b0538..2bda124681 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -1,3 +1,4 @@
+from datetime import timedelta
 import os
 import time
 import warnings
@@ -39,7 +40,8 @@
     # Avoids the allocator thrashing issue in PyTorch NCCL backend.
     # See https://github.com/Lightning-AI/lightning-thunder/issues/420
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-    torch_dist.init_process_group(backend="nccl")
+    os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
+    torch_dist.init_process_group(backend="nccl", timeout=timedelta(minutes=5))
     pg = torch_dist.distributed_c10d._get_default_group()
     device = torch.device("cuda", local_rank)
     torch.cuda.set_device(device)
@@ -118,6 +120,8 @@ def __init__(
         profiler_stop: int = 15,
         skip_data_sync: bool = False,
         low_precision_mode: str = "none",
+        max_iters: int = 45,
+        warmup_iters: int = 25,
     ):
         seed = 1337
         torch.manual_seed(seed)
@@ -129,9 +133,9 @@ def __init__(
         beta1 = 0.9
         beta2 = 0.95
 
-        self.max_iters = 45
-        self.warmup_iter = 25
-        assert self.max_iters > self.warmup_iter
+        self.max_iters = max_iters
+        self.warmup_iters = warmup_iters
+        assert self.max_iters > self.warmup_iters
 
         self.device = device
         self.model_name = model_name
@@ -436,7 +440,7 @@ def train(self):
             # Calculate the model FLOPs
             self.calculate_model_flops()
             # Setup throughput Collection
-            self.throughput = Throughput(window_size=self.max_iters - self.warmup_iter, world_size=world_size)
+            self.throughput = Throughput(window_size=self.max_iters - self.warmup_iters, world_size=world_size)
         except:
             self.throughput = None
             print(
@@ -450,7 +454,7 @@ def train(self):
 
         for i in range(self.max_iters):
             iter_t0 = time.perf_counter()
-            if i == self.warmup_iter:  # warmup
+            if i == self.warmup_iters:  # warmup
                 t0 = iter_t0
 
             if self.nsys_enabled and i == self.profiler_start and global_rank in [0, None]:
@@ -504,7 +508,7 @@ def train(self):
                 print(
                     f"iter {i}: loss {loss_item:.4f}, iter time: {(t1 - iter_t0) * 1000:.2f}ms, t: {input_ids.size(1)}"
                 )
-            if i >= self.warmup_iter:
+            if i >= self.warmup_iters:
                 if self.throughput:
                     self.throughput.update(
                         time=(t1 - t0),
@@ -518,7 +522,7 @@ def train(self):
 
         if global_rank in [0, None]:
             # print(f"Total time: {(t1 - t0):.2f}s")
-            self.perf_metrics["average_iter_time"] = ((t1 - t0) * 1000) / (self.max_iters - self.warmup_iter)
+            self.perf_metrics["average_iter_time"] = ((t1 - t0) * 1000) / (self.max_iters - self.warmup_iters)
 
     def add_perf_metrics(self):
         if self.throughput:
@@ -576,41 +580,35 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
     """
 
     benchmark = Benchmark_litGPT(**kwargs)
-    try:
-        benchmark.train()
+    benchmark.train()
 
-        if global_rank in [0, None]:
-            benchmark.add_perf_metrics()
+    if global_rank in [0, None]:
+        benchmark.add_perf_metrics()
 
-            print(
-                f"Model name: {benchmark.model_name}\nSeq Length: {benchmark.config.block_size}\nMicro BS: {benchmark.micro_batch_size}\nGlobal BS: {benchmark.global_batch_size}"
-            )
-            print(
-                f"Number of Layers: {benchmark.config.n_layer}\nNumber of parameters: {sum(p.numel() for p in benchmark.model.parameters() if p.requires_grad) / 1e9:.02f}B"
-            )
-            print(f"Distributed Mode: {benchmark.distributed_mode}")
-            if benchmark.distributed_mode == "fsdp":
-                print(f"Sharding Mode: {benchmark.shard_mode}\nBucketing: {benchmark.bucketing_mode}")
-                if benchmark.sharding_size is not None:
-                    print(
-                        f"Sharding Size: {benchmark.sharding_size}\nReplicate DP Groups: {int(world_size/benchmark.sharding_size)}"
-                    )
-                if benchmark.bucketing_mode == "size":
-                    print(f"Bucketing Number Params: {benchmark.fsdp_bucket_params}")
-            elif benchmark.distributed_mode == "ddp":
-                print(f"DDP Bucketing Size: {benchmark.ddp_bucket_size} MB")
-            print(f"Compiler: {benchmark.compile}")
-            print(f"Low Precision Mode: {benchmark.low_precision_mode}")
-            print(f"Average iter time: {benchmark.perf_metrics['average_iter_time']:.2f} ms")
-            print(f"Memory used: {benchmark.perf_metrics['memory_used_GB']:.02f} GB")
-            print(f"Tokens/s: {benchmark.perf_metrics['tokens_per_sec']:.02f}")
-            print(f"Tokens/s/GPU: {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f}")
-            print(f"TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec'] / 1e12:.02f}")
-
-    except Exception as error:
-        # Helps catch OutOfMemory Errors and post processing of errors
-        if global_rank in [0, None]:
-            print("An error occurred:", type(error).__name__, "–", error)
+        print(
+            f"Model name: {benchmark.model_name}\nSeq Length: {benchmark.config.block_size}\nMicro BS: {benchmark.micro_batch_size}\nGlobal BS: {benchmark.global_batch_size}"
+        )
+        print(
+            f"Number of Layers: {benchmark.config.n_layer}\nNumber of parameters: {sum(p.numel() for p in benchmark.model.parameters() if p.requires_grad) / 1e9:.02f}B"
+        )
+        print(f"Distributed Mode: {benchmark.distributed_mode}")
+        if benchmark.distributed_mode == "fsdp":
+            print(f"Sharding Mode: {benchmark.shard_mode}\nBucketing: {benchmark.bucketing_mode}")
+            if benchmark.sharding_size is not None:
+                print(
+                    f"Sharding Size: {benchmark.sharding_size}\nReplicate DP Groups: {int(world_size/benchmark.sharding_size)}"
+                )
+            if benchmark.bucketing_mode == "size":
+                print(f"Bucketing Number Params: {benchmark.fsdp_bucket_params}")
+        elif benchmark.distributed_mode == "ddp":
+            print(f"DDP Bucketing Size: {benchmark.ddp_bucket_size} MB")
+        print(f"Compiler: {benchmark.compile}")
+        print(f"Low Precision Mode: {benchmark.low_precision_mode}")
+        print(f"Average iter time: {benchmark.perf_metrics['average_iter_time']:.2f} ms")
+        print(f"Memory used: {benchmark.perf_metrics['memory_used_GB']:.02f} GB")
+        print(f"Tokens/s: {benchmark.perf_metrics['tokens_per_sec']:.02f}")
+        print(f"Tokens/s/GPU: {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f}")
+        print(f"TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec'] / 1e12:.02f}")
 
     if global_rank in [0, None]:
         if return_metrics_as_json:
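
For reference, a minimal standalone sketch of the process-group setup this patch moves to. The env vars and the 5-minute timeout mirror the diff; the LOCAL_RANK lookup assumes a torchrun-style launcher and is not part of the patch:

from datetime import timedelta
import os

import torch
import torch.distributed as torch_dist

# Both knobs must be set before init_process_group is called.
# Avoids allocator thrashing in the PyTorch NCCL backend (see issue 420 above).
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# With async error handling enabled, a collective that exceeds the timeout
# aborts the process instead of leaving ranks hung indefinitely.
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"

torch_dist.init_process_group(backend="nccl", timeout=timedelta(minutes=5))

local_rank = int(os.environ["LOCAL_RANK"])  # assumes a torchrun-style launcher
torch.cuda.set_device(torch.device("cuda", local_rank))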
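
Because benchmark_main forwards **kwargs to Benchmark_litGPT, the newly exposed iteration counts can be passed straight through. A usage sketch; the values are illustrative, not the patch defaults (which stay at 45/25), and any other constructor kwargs such as model_name pass through the same way:

from thunder.benchmarks.benchmark_litgpt import benchmark_main

# warmup_iters must stay below max_iters (the constructor asserts this);
# throughput and average_iter_time are measured over the final
# max_iters - warmup_iters iterations only.
benchmark_main(max_iters=60, warmup_iters=20)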