From 2e3d5e28e1a898b898f9a9829a0a6227c7654a21 Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Fri, 22 Mar 2024 14:06:15 +0900
Subject: [PATCH] Enable gradient accumulation in litgpt benchmark

Signed-off-by: Masaki Kozuki
---
Review notes (text between the three-dash line and the diff is not part of
the commit message):

* train() now calls optimizer.step() / zero_grad() unconditionally, once per
  iteration. The earlier revision kept the guard
  `(step_idx + 1) % self.gradient_accumulation_steps == 0` after shortening
  the microbatch loop to `range(self.gradient_accumulation_steps - 1)`, so
  the guard could never be true (the optimizer never stepped), and
  `step_idx` was unbound when gradient_accumulation_steps == 1.
* The NSYS start hook moved out of the microbatch loop. Inside the loop it
  never fired when gradient_accumulation_steps == 1 because the loop body
  does not run in that case. The capture now also covers loading the first
  microbatch.
* Not changed by this patch: the context assert
  `self.global_batch_size % self.micro_batch_size * world_size == 0`
  evaluates as `(a % b) * c == 0`, i.e. it only checks divisibility by the
  micro batch size, not by micro batch size * world size as its message says.
* NOTE(review): confirm that `nullcontext` is in scope on every path that
  reaches `data_sync_ctx = nullcontext`, and that `self.model` provides
  `no_sync` whenever `skip_data_sync=True` is passed.
* NOTE(review): this patch text was recovered from a whitespace-collapsed
  copy. Line structure was rebuilt from the hunk line counts; indentation of
  context lines is a best guess. Verify against the target file or apply with
  `git apply --ignore-whitespace`. The commit id and the post-image blob id
  on the `index` line were not regenerated.

 thunder/benchmarks/benchmark_litgpt.py | 85 +++++++++++++++-----------
 1 file changed, 49 insertions(+), 36 deletions(-)

diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 877f99f38e..c73a6c9ae0 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -1,5 +1,6 @@
 import os
 import time
+from typing import Any
 
 import torch
 import functools
@@ -34,6 +35,23 @@ def configure_optimizers(model, weight_decay, learning_rate, betas, device_type)
     return optimizer
 
 
+def run_fwd_bwd_one_microbatch(
+    model: torch.nn.Module,
+    input_ids: torch.Tensor,
+    targets: torch.Tensor,
+    gradient_accumulation_steps: int,
+    te_ctx: Any,
+) -> torch.Tensor:
+    """Run forward and backward on one microbatch; return the loss divided by `gradient_accumulation_steps`."""
+    with te_ctx():
+        logits = model(input_ids)
+    logits = logits.reshape(-1, logits.size(-1))
+    targets = targets.reshape(-1)
+    loss = torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) / gradient_accumulation_steps
+    loss.backward()
+    return loss
+
+
 class Benchmark_litGPT:
     def __init__(
         self,
@@ -50,6 +68,7 @@ def __init__(
         n_layers: int | None = None,
         profiler_start: int = 15,
         profiler_stop: int = 15,
+        skip_data_sync: bool = False,
     ):
         seed = 1337
         torch.manual_seed(seed)
@@ -90,11 +109,7 @@ def __init__(
             assert (
                 self.global_batch_size % self.micro_batch_size * world_size == 0
             ), f"Global Batch Size {self.global_batch_size} should be a multiple Micro Batch Size {self.micro_batch_size} * World Size {world_size}."
-        # TODO: Remove when gradient accumulation is ready for benchmarking.
-        if self.gradient_accumulation_steps > 1:
-            print(
-                f"[WARNING] Gradient Accumulation is not fully supported yet. Benchmarking results may not be accurate. Gradient Accumulation Steps = {self.gradient_accumulation_steps}"
-            )
+        self.skip_data_sync = skip_data_sync
 
         # Profiling Args
         self.nsys_enabled = nsys_enabled
@@ -280,46 +295,44 @@ def train(self):
 
             te_ctx = nullcontext
 
+        if self.skip_data_sync:
+            data_sync_ctx = self.model.no_sync
+        else:
+            data_sync_ctx = nullcontext
+
         for i in range(self.max_iters):
             iter_t0 = time.perf_counter()
             if i == self.warmup_iter:  # warmup
                 t0 = iter_t0
 
-            for step_idx in range(self.gradient_accumulation_steps):
-                input_ids, targets = next(self.train_data_iter)
-                input_ids = input_ids.to(device=self.device)
-                targets = targets.to(device=self.device)
+            # Start profiling before the first microbatch so this also fires when there is only one microbatch.
+            if self.nsys_enabled and i == self.profiler_start and global_rank in [0, None]:
+                print("=====Start NSYS Profiling======")
+                torch.cuda.cudart().cudaProfilerStart()
 
-                if self.nsys_enabled and i == self.profiler_start and global_rank in [0, None] and step_idx == 0:
-                    print("=====Start NSYS Profiling======")
-                    torch.cuda.cudart().cudaProfilerStart()
+            # All but the last microbatch run inside `data_sync_ctx` (`no_sync` when `skip_data_sync` is set).
+            with data_sync_ctx():
+                for _ in range(self.gradient_accumulation_steps - 1):
+                    input_ids, targets = next(self.train_data_iter)
+                    input_ids = input_ids.to(device=self.device)
+                    targets = targets.to(device=self.device)
+                    run_fwd_bwd_one_microbatch(
+                        self.model, input_ids, targets, self.gradient_accumulation_steps, te_ctx
+                    )
 
-                with te_ctx():
-                    logits = self.model(input_ids)
+            # The last microbatch runs outside `data_sync_ctx`; its loss is what `loss.item()` reads below.
+            input_ids, targets = next(self.train_data_iter)
+            input_ids = input_ids.to(device=self.device)
+            targets = targets.to(device=self.device)
+            loss = run_fwd_bwd_one_microbatch(self.model, input_ids, targets, self.gradient_accumulation_steps, te_ctx)
 
-                logits = logits.reshape(-1, logits.size(-1))
-                targets = targets.reshape(-1)
-                loss = (
-                    torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
-                    / self.gradient_accumulation_steps
-                )
+            # All microbatches of this iteration have been accumulated, so step unconditionally.
+            self.optimizer.step()
+            self.optimizer.zero_grad(set_to_none=True)
 
-                loss.backward()
-
-                # Simple Gradient Accumulation Implementation
-                if (step_idx + 1) % self.gradient_accumulation_steps == 0:
-                    self.optimizer.step()
-                    self.optimizer.zero_grad(set_to_none=True)
-
-                # torch.cuda.synchronize()
-                if (
-                    self.nsys_enabled
-                    and i == self.profiler_stop
-                    and global_rank in [0, None]
-                    and ((step_idx + 1) % self.gradient_accumulation_steps == 0)
-                ):
-                    print("=====Stop NSYS Profiling======")
-                    torch.cuda.cudart().cudaProfilerStop()
+            if self.nsys_enabled and i == self.profiler_stop and global_rank in [0, None]:
+                print("=====Stop NSYS Profiling======")
+                torch.cuda.cudart().cudaProfilerStop()
 
             loss_item = loss.item()  # synchronization
             t1 = time.perf_counter()