
Commit

Enable gradient accumulation in litgpt benchmark
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
crcrpar committed Mar 25, 2024
1 parent fbffc47 commit 2e3d5e2
Showing 1 changed file with 47 additions and 36 deletions.
83 changes: 47 additions & 36 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -1,5 +1,6 @@
import os
import time
from typing import Any

import torch
import functools
@@ -34,6 +35,22 @@ def configure_optimizers(model, weight_decay, learning_rate, betas, device_type)
return optimizer


def run_fwd_bwd_one_microbatch(
model: torch.nn.Module,
input_ids: torch.Tensor,
targets: torch.Tensor,
gradient_accumulation_steps: int,
te_ctx: Any,
) -> torch.Tensor:
with te_ctx():
logits = model(input_ids)
logits = logits.reshape(-1, logits.size(-1))
targets = targets.reshape(-1)
loss = torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1) / gradient_accumulation_steps
loss.backward()
return loss
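(Editor's note, not part of the commit: a minimal, hypothetical usage sketch for the helper above. A toy nn.Linear stands in for the LitGPT model, contextlib.nullcontext stands in for the TransformerEngine context passed as te_ctx, and all names below are illustrative only.)

from contextlib import nullcontext
import torch

toy_model = torch.nn.Linear(8, 8)  # toy stand-in for the LitGPT model
optimizer = torch.optim.SGD(toy_model.parameters(), lr=1e-3)
grad_accum_steps = 4
batches = [(torch.randn(2, 8), torch.randint(0, 8, (2,))) for _ in range(grad_accum_steps)]

for input_ids, targets in batches:
    # Each call scales the loss by 1/grad_accum_steps and accumulates gradients via backward().
    loss = run_fwd_bwd_one_microbatch(toy_model, input_ids, targets, grad_accum_steps, nullcontext)

optimizer.step()  # one optimizer update after all microbatches have accumulated gradients
optimizer.zero_grad(set_to_none=True)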


class Benchmark_litGPT:
def __init__(
self,
@@ -50,6 +67,7 @@ def __init__(
n_layers: int | None = None,
profiler_start: int = 15,
profiler_stop: int = 15,
skip_data_sync: bool = False,
):
seed = 1337
torch.manual_seed(seed)
@@ -90,11 +108,7 @@ def __init__(
assert (
self.global_batch_size % (self.micro_batch_size * world_size) == 0
), f"Global Batch Size {self.global_batch_size} should be a multiple of Micro Batch Size {self.micro_batch_size} * World Size {world_size}."
# TODO: Remove when gradient accumulation is ready for benchmarking.
if self.gradient_accumulation_steps > 1:
print(
f"[WARNING] Gradient Accumulation is not fully supported yet. Benchmarking results may not be accurate. Gradient Accumulation Steps = {self.gradient_accumulation_steps}"
)
self.skip_data_sync = skip_data_sync

# Profiling Args
self.nsys_enabled = nsys_enabled
@@ -280,46 +294,43 @@ def train(self):

te_ctx = nullcontext

if self.skip_data_sync:
data_sync_ctx = self.model.no_sync
else:
data_sync_ctx = nullcontext

for i in range(self.max_iters):
iter_t0 = time.perf_counter()
if i == self.warmup_iter: # warmup
t0 = iter_t0

for step_idx in range(self.gradient_accumulation_steps):
input_ids, targets = next(self.train_data_iter)
input_ids = input_ids.to(device=self.device)
targets = targets.to(device=self.device)
with data_sync_ctx():
for step_idx in range(self.gradient_accumulation_steps - 1):
input_ids, targets = next(self.train_data_iter)
input_ids = input_ids.to(device=self.device)
targets = targets.to(device=self.device)

if self.nsys_enabled and i == self.profiler_start and global_rank in [0, None] and step_idx == 0:
print("=====Start NSYS Profiling======")
torch.cuda.cudart().cudaProfilerStart()

if self.nsys_enabled and i == self.profiler_start and global_rank in [0, None] and step_idx == 0:
print("=====Start NSYS Profiling======")
torch.cuda.cudart().cudaProfilerStart()
loss = run_fwd_bwd_one_microbatch(
self.model, input_ids, targets, self.gradient_accumulation_steps, te_ctx
)

with te_ctx():
logits = self.model(input_ids)
input_ids, targets = next(self.train_data_iter)
input_ids = input_ids.to(device=self.device)
targets = targets.to(device=self.device)
loss = run_fwd_bwd_one_microbatch(self.model, input_ids, targets, self.gradient_accumulation_steps, te_ctx)

logits = logits.reshape(-1, logits.size(-1))
targets = targets.reshape(-1)
loss = (
torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
/ self.gradient_accumulation_steps
)
# Simple Gradient Accumulation Implementation
if (step_idx + 1) % self.gradient_accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)

loss.backward()

# Simple Gradient Accumulation Implementation
if (step_idx + 1) % self.gradient_accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)

# torch.cuda.synchronize()
if (
self.nsys_enabled
and i == self.profiler_stop
and global_rank in [0, None]
and ((step_idx + 1) % self.gradient_accumulation_steps == 0)
):
print("=====Stop NSYS Profiling======")
torch.cuda.cudart().cudaProfilerStop()
if self.nsys_enabled and i == self.profiler_stop and global_rank in [0, None]:
print("=====Stop NSYS Profiling======")
torch.cuda.cudart().cudaProfilerStop()

loss_item = loss.item() # synchronization
t1 = time.perf_counter()
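(Editor's note: to summarize the pattern this commit introduces, below is a hedged, self-contained sketch of one training iteration with gradient accumulation; the function and variable names are hypothetical, not the benchmark's exact code. When skip_data_sync is set and the model is wrapped in DistributedDataParallel, model.no_sync() defers gradient all-reduce for the first N-1 microbatches, so only the final microbatch, run outside the context, triggers communication before the optimizer step.)

from contextlib import nullcontext

def train_one_iteration(model, optimizer, data_iter, grad_accum_steps, te_ctx, skip_data_sync):
    # With torch.nn.parallel.DistributedDataParallel, no_sync() suppresses gradient
    # all-reduce inside the context; plain modules fall back to nullcontext.
    sync_ctx = model.no_sync if (skip_data_sync and hasattr(model, "no_sync")) else nullcontext

    with sync_ctx():
        for _ in range(grad_accum_steps - 1):
            input_ids, targets = next(data_iter)
            run_fwd_bwd_one_microbatch(model, input_ids, targets, grad_accum_steps, te_ctx)

    # The last microbatch runs outside no_sync so gradients are all-reduced exactly once.
    input_ids, targets = next(data_iter)
    loss = run_fwd_bwd_one_microbatch(model, input_ids, targets, grad_accum_steps, te_ctx)

    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return loss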
