49 changes: 23 additions & 26 deletions examples/lit-gpt/test_parametrized.py
@@ -7,12 +7,13 @@
MID_BENCHMARK_OUT - use this env variable to control whether you want to see the combined results
between each test.
BENCHMARK_OUT_FORMAT - use this env variable to control the format in which the results are presented.
Uses 'xlsx' by default. More format support to come soon.
Uses 'xlsx' by default. Supported: 'none', 'print', 'xlsx'.
'''
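(Editor's note: a minimal sketch of how the environment variables documented above could be read; the parsing below is an assumption and not part of this PR — only the variable names and the 'none'/'print'/'xlsx' values come from the docstring and diff.)

import os

show_mid_results = os.environ.get("MID_BENCHMARK_OUT", "0") == "1"   # assumed truthy convention
output_format = os.environ.get("BENCHMARK_OUT_FORMAT", "xlsx")       # one of 'none', 'print', 'xlsx'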

import torch
from absl.testing import parameterized
from absl.testing import absltest
from collections import defaultdict
import os
import subprocess
import json
@@ -48,6 +49,9 @@ def add_to_dataframe(self):
self.dataframe_data.append(self.perf_metrics_dict)

def complete_dataframe(self, is_teardown):
if not self.dataframe_data:
# The benchmark probably failed
return
#Called when tearing down the parametrized test
#This generates a summarized dataframe for each perf metric and saves as a xlsx file
df = pd.DataFrame(self.dataframe_data)
@@ -59,7 +63,7 @@ def complete_dataframe(self, is_teardown):
self.tokens_per_sec_per_gpu_df = df.pivot_table(index=index_list, columns='compiler', values='tokens_per_sec_per_gpu', aggfunc='first').reset_index()
self.memory_used_GB_df = df.pivot_table(index=index_list, columns='compiler', values='memory_used_GB', aggfunc='first').reset_index()

if self.output_format not in ('none', 'print'):
if self.output_format == "xlsx":
output_ext = {'xlsx': '.xlsx', }[self.output_format]
if not is_teardown:
filename = 'examples/lit-gpt/mid_output_parameterized_results' + str(output_ext)
@@ -84,7 +88,6 @@ def complete_dataframe(self, is_teardown):
print(self.memory_used_GB_df)

def run_benchmark(self, kwargs):
# benchmark_file = 'thunder/benchmarks/benchmark_litgpt.py'
command_list = []
for key, val in kwargs.items():
command_list.append("--" + str(key) + "=" + str(val))
@@ -98,32 +101,26 @@ def run_benchmark(self, kwargs):

print(f'Running {" ".join(subprocess_cmd)!r}')
proc_output = subprocess.run(subprocess_cmd, capture_output=True, text=True)

self.perf_metrics_dict = {}
if os.path.exists(self.json_file_path):
with open(self.json_file_path, 'r') as file:
self.perf_metrics_dict = json.load(file)
# Cleanup after the benchmark finishes. It might have failed before creating this
os.remove(self.json_file_path)

if proc_output.returncode:
print(proc_output.stdout)
print(proc_output.stderr)
proc_output.check_returncode()

with open(self.json_file_path, 'r') as file:
self.perf_metrics_dict = json.load(file)
os.remove(self.json_file_path) #cleanup after test finishes

if self.perf_metrics_dict['average_iter_time'] is None:
if 'CUDA out of memory' in proc_output.stdout:
self.perf_metrics_dict['average_iter_time'] = 'OOM'
self.perf_metrics_dict['model_flops'] = 'OOM'
self.perf_metrics_dict['model_flop_per_sec'] = 'OOM'
self.perf_metrics_dict['tokens_per_sec'] = 'OOM'
self.perf_metrics_dict['tokens_per_sec_per_gpu'] = 'OOM'
self.perf_metrics_dict['memory_used_GB'] = 'OOM'
if 'CUDA out of memory' in proc_output.stdout or "CUDA error: out of memory" in proc_output.stderr:
defaultdict_oom = defaultdict(lambda: "OOM")
defaultdict_oom.update(self.perf_metrics_dict)
self.perf_metrics_dict = defaultdict_oom
pass_str = "TestCase did not finish reporting metrics due to CUDA out of memory error. Reporting OOM and triggering test success."
return True, pass_str
else:
print(proc_output.stdout)
print(proc_output.stderr)
fail_str = "Testcase did not finish reporting metrics due to an unknown error. Triggering test failure."
return False, fail_str
else:
return True, "Test passed successfully."
print(proc_output.stdout)
print(proc_output.stderr)
fail_str = "TestCase did not finish reporting metrics due to an unknown error. Triggering test failure."
return False, fail_str
return True, "Test passed successfully."


class Test(parameterized.TestCase):
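(Editor's note: the OOM handling introduced above swaps per-key "OOM" assignments for a defaultdict whose factory returns "OOM", so every metric the failed run never reported still reads as "OOM". A minimal, self-contained sketch of that pattern; the metric and model names below are placeholders:)

from collections import defaultdict

# Partial metrics from a run that hit CUDA OOM before reporting everything.
perf_metrics = {"model_name": "some-model", "Num GPUS": 8}

# Any key that was never written falls back to "OOM" instead of raising KeyError.
perf_metrics_with_fallback = defaultdict(lambda: "OOM")
perf_metrics_with_fallback.update(perf_metrics)

print(perf_metrics_with_fallback["average_iter_time"])  # -> "OOM"
print(perf_metrics_with_fallback["model_name"])          # -> "some-model"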
99 changes: 41 additions & 58 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -9,13 +9,9 @@
import thunder
from thunder.tests.lit_gpt_model import Config, GPT, Block

try:
from lightning.fabric.utilities.throughput import measure_flops
from lightning.fabric.utilities.throughput import measure_flops
from lightning.fabric.utilities import Throughput

# from lightning.fabric.utilities import Throughput
LIGHTNING_AVAILABLE = True
except:
LIGHTNING_AVAILABLE = False

world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
@@ -109,7 +105,10 @@ def __init__(
self.config.n_layer = n_layers

# Initialize the model
t0 = time.perf_counter()
print(f"Loading model with {self.config.__dict__}")
self.model = self.init_model()
print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")

# Setup the distributed algorithm choices
if self.distributed_mode != "none":
@@ -138,14 +137,10 @@ def __init__(
}

def init_model(self):
print(f"Loading model with {self.config.__dict__}")
init_device = torch.device("meta") if self.distributed_mode == "fsdp" else self.device
t0 = time.perf_counter()
with init_device:
model = GPT(self.config)
model.to(dtype=torch.bfloat16)
print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")

model.to(dtype=torch.bfloat16)
return model

def setup_distributed(self):
@@ -243,32 +238,38 @@ def pad_collate(batch):
y_padded = torch.nn.utils.rnn.pad_sequence(y, batch_first=True, padding_value=-1)
return x_padded, y_padded

train_data = DummyDataset(self.model.max_seq_length, self.dynamic)
train_data = DummyDataset(self.config.block_size, self.dynamic)
train_dataloader = DataLoader(
train_data, batch_size=self.micro_batch_size, num_workers=2, collate_fn=pad_collate
)

return train_dataloader

def calculate_model_flops(self):
input_ids, targets = next(self.train_data_iter)
input_ids = input_ids.to(device=self.device)
targets = targets.to(device=self.device)
meta = torch.device("meta")
device = self.device
self.device = meta

# calculate flops on a meta-device model because we only care about the shapes and
# because the flops calculator installs hooks on the model
meta_model = self.init_model()

model_fwd = lambda: self.model(input_ids)
x = torch.randint(0, 1, (self.micro_batch_size, meta_model.config.block_size), device=meta)
model_fwd = lambda: meta_model(x)
model_loss = lambda y: torch.nn.functional.cross_entropy(
y.reshape(-1, y.size(-1)), targets.reshape(-1), ignore_index=-1
y.reshape(-1, y.size(-1)), x.reshape(-1), ignore_index=-1
)
if LIGHTNING_AVAILABLE:
self.perf_metrics["model_flops"] = measure_flops(self.model, model_fwd, model_loss) / 1e12
self.perf_metrics["model_flops"] = measure_flops(meta_model, model_fwd, model_loss)

self.device = device

def train(self):
t0 = None
# if global_rank in [0, None]:
# #Calculate the model FLOPs
# self.calculate_model_flops()
# Setup Perf Collection
# self.throughput = Throughput(window_size=10, world_size=world_size)
if global_rank in [0, None]:
# Calculate the model FLOPs
self.calculate_model_flops()
# Setup throughput Collection
self.throughput = Throughput(window_size=self.max_iters - self.warmup_iter, world_size=world_size)

if "transformerengine" in self.compile:
import transformer_engine.pytorch as te
@@ -326,45 +327,30 @@ def train(self):
print(
f"iter {i}: loss {loss_item:.4f}, iter time: {(t1 - iter_t0) * 1000:.2f}ms, t: {input_ids.size(1)}"
)

# if global_rank in [0, None] and i >=warmup_iter:
# self.throughput.update(
# time=(t1-t0),
# flops=self.model_flops,
# batches=i,
# samples=(i * self.micro_batch_size * self.gradient_accumulation_steps),
# lengths=(i * self.micro_batch_size * self.gradient_accumulation_steps * self.model.max_seq_length),
# )

# metrics = self.throughput.compute()
# if i % 10 == 0:
# print(metrics)
if i >= self.warmup_iter:
self.throughput.update(
time=(t1 - t0),
flops=self.perf_metrics["model_flops"],
batches=i,
samples=(i * self.micro_batch_size * self.gradient_accumulation_steps),
lengths=(i * self.micro_batch_size * self.gradient_accumulation_steps * self.config.block_size),
)

if global_rank in [0, None]:
# print(f"Total time: {(t1 - t0):.2f}s")
# print(f"Average time per iter: {((t1 - t0)*1000)/(max_iters-warmup_iter):.2f}ms")
self.perf_metrics["average_iter_time"] = ((t1 - t0) * 1000) / (self.max_iters - self.warmup_iter)

def add_perf_metrics(self):
# tokens_per_sec = total number of benchmarked iterations x global BS x block_size / total elapsed time (s)
# = global BS x block_size / (total elapsed time (s)/total number of benchmarked iterations)
# = global BS x block_size / average iter time (s)
self.perf_metrics["tokens_per_sec"] = (
self.global_batch_size * self.model.max_seq_length * 1000 / self.perf_metrics["average_iter_time"]
) # tokens/s
if self.perf_metrics["model_flops"] is not None:
self.perf_metrics["model_flop_per_sec"] = (
self.perf_metrics["model_flops"] * 1000 / self.perf_metrics["average_iter_time"]
)
if world_size is not None:
self.perf_metrics["model_flop_per_sec"] *= world_size
metrics = self.throughput.compute()
self.perf_metrics["tokens_per_sec"] = metrics.get("items_per_sec", metrics["device/items_per_sec"])
self.perf_metrics["model_flop_per_sec"] = metrics.get("flops_per_sec", metrics["device/flops_per_sec"])
self.perf_metrics["memory_used_GB"] = torch.cuda.max_memory_allocated() / 1e9

def add_model_info_to_metrics(self):
if global_rank in [0, None]:
self.perf_metrics["model_name"] = self.model_name
self.perf_metrics["Num GPUS"] = world_size
self.perf_metrics["Seq Len"] = self.model.max_seq_length
self.perf_metrics["Seq Len"] = self.config.block_size
self.perf_metrics["Micro BS"] = self.micro_batch_size
self.perf_metrics["Global BS"] = self.global_batch_size
self.perf_metrics["GA"] = self.gradient_accumulation_steps
@@ -416,7 +402,7 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
benchmark.add_perf_metrics()

print(
f"Model name: {benchmark.model_name}\nSeq Length: {benchmark.model.max_seq_length}\nMicro BS: {benchmark.micro_batch_size}\nGlobal BS: {benchmark.global_batch_size}"
f"Model name: {benchmark.model_name}\nSeq Length: {benchmark.config.block_size}\nMicro BS: {benchmark.micro_batch_size}\nGlobal BS: {benchmark.global_batch_size}"
)
print(
f"Number of Layers: {benchmark.config.n_layer}\nNumber of parameters: {sum(p.numel() for p in benchmark.model.parameters() if p.requires_grad) / 1e9:.02f}B"
@@ -429,12 +415,9 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
print(f"Compiler: {benchmark.compile}")
print(f"Average iter time: {benchmark.perf_metrics['average_iter_time']:.2f} ms")
print(f"Memory used: {benchmark.perf_metrics['memory_used_GB']:.02f} GB")
print(f"Throughput (Tokens/s): {benchmark.perf_metrics['tokens_per_sec']:.02f} tokens/s")
print(
f"Normalized Throughput (Tokens/s/GPU): {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f} tokens/s/gpu"
)
if benchmark.perf_metrics["model_flop_per_sec"] is not None:
print(f"Model TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec']:.02f} TFLOP/s")
print(f"Tokens/s: {benchmark.perf_metrics['tokens_per_sec']:.02f}")
print(f"Tokens/s/GPU: {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f}")
print(f"TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec'] / 1e12:.02f}")

except Exception as error:
# Helps catch OutOfMemory Errors and post processing of errors
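(Editor's note: a self-contained sketch of the two Lightning Fabric utilities this PR switches to — measure_flops on a meta-device model and Throughput for token and FLOP rates. The toy model, sizes, and sleep-based loop are placeholders, the exact key set returned by Throughput.compute() may vary with the lightning version, and lightning must be installed.)

import time
import torch
from lightning.fabric.utilities import Throughput
from lightning.fabric.utilities.throughput import measure_flops

batch_size, seq_len, vocab = 2, 128, 1000

# 1) FLOP measurement on the meta device: only shapes are needed, no real memory,
#    and the flop counter installs hooks on whichever model it is given.
with torch.device("meta"):
    meta_model = torch.nn.Sequential(torch.nn.Embedding(vocab, 64), torch.nn.Linear(64, vocab))
    x = torch.randint(0, vocab, (batch_size, seq_len))

model_fwd = lambda: meta_model(x)
model_loss = lambda y: torch.nn.functional.cross_entropy(y.reshape(-1, vocab), x.reshape(-1))
flops_per_step = measure_flops(meta_model, model_fwd, model_loss)

# 2) Throughput: feed cumulative counters each iteration, then read the rates.
throughput = Throughput(window_size=10, world_size=1)
t0 = time.perf_counter()
for i in range(1, 21):
    time.sleep(0.01)  # stand-in for one optimizer step
    throughput.update(
        time=time.perf_counter() - t0,  # cumulative elapsed time
        flops=flops_per_step,
        batches=i,
        samples=i * batch_size,
        lengths=i * batch_size * seq_len,  # cumulative token count
    )

metrics = throughput.compute()
# Per-device rates are prefixed with 'device/'; un-prefixed global keys appear
# only when world_size > 1, which is why add_perf_metrics() above falls back
# to the 'device/' variants.
print(metrics["device/items_per_sec"], metrics["device/flops_per_sec"])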