Using median iteration time to avoid outliers skewing data
parthmannan committed Mar 13, 2024
1 parent a05ada7 commit 4596a67
Showing 2 changed files with 16 additions and 12 deletions.
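
The rationale behind the change: a single slow iteration (a dataloader stall, garbage collection, a clock-frequency dip) can noticeably inflate a mean taken over a few dozen iterations, while the median is insensitive to such outliers. A minimal sketch with hypothetical timings, using the same numpy call the commit introduces:

    import numpy as np

    # Hypothetical per-iteration wall times in ms; one stray 900 ms hiccup
    iter_times = [101.2, 99.8, 100.5, 900.0, 100.1, 99.9, 100.4]

    print(f"mean:   {np.mean(iter_times):.2f} ms")    # 214.56 ms, dragged up by the outlier
    print(f"median: {np.median(iter_times):.2f} ms")  # 100.40 ms, unaffected by it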
12 changes: 6 additions & 6 deletions examples/lit-gpt/test_parametrized.py
@@ -57,7 +57,7 @@ def complete_dataframe(self, is_teardown):
         df['Sharding Size'] = df['Sharding Size'].fillna('none') #Convert None Type to string so that pivot table can group.
         index_list = ['model_name', 'Num GPUS', 'Seq Len', 'Micro BS', 'Global BS', 'GA', 'Distributed Mode', 'Sharding Size']

-        self.iter_time_df = df.pivot_table(index=index_list, columns='compiler', values='average_iter_time', aggfunc='first').reset_index()
+        self.iter_time_df = df.pivot_table(index=index_list, columns='compiler', values='median_iter_time', aggfunc='first').reset_index()
         self.tokens_per_sec_df = df.pivot_table(index=index_list, columns='compiler', values='tokens_per_sec', aggfunc='first').reset_index()
         self.tokens_per_sec_per_gpu_df = df.pivot_table(index=index_list, columns='compiler', values='tokens_per_sec_per_gpu', aggfunc='first').reset_index()
         self.memory_used_GB_df = df.pivot_table(index=index_list, columns='compiler', values='memory_used_GB', aggfunc='first').reset_index()
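
For context on the line being changed: each benchmark run contributes one row to df, and pivot_table with columns='compiler' and aggfunc='first' spreads the per-compiler values into side-by-side columns ('first' simply picks the single value present in each cell). A small sketch with made-up data and a shortened index (the real code indexes on many more keys):

    import pandas as pd

    # Made-up results: one row per (model, compiler) run
    df = pd.DataFrame({
        'model_name': ['Llama-2-7b-hf', 'Llama-2-7b-hf'],
        'compiler': ['eager', 'thunder'],
        'median_iter_time': [320.5, 298.1],
    })

    pivot = df.pivot_table(index='model_name', columns='compiler',
                           values='median_iter_time', aggfunc='first').reset_index()
    print(pivot)
    # Prints roughly:
    # compiler     model_name  eager  thunder
    # 0         Llama-2-7b-hf  320.5    298.1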
@@ -72,12 +72,12 @@ def complete_dataframe(self, is_teardown):
             filename = '/scratch/lightning-thunder/examples/lit-gpt/' + str(filename)

             with pd.ExcelWriter(filename, engine='xlsxwriter') as writer:
-                self.iter_time_df.to_excel(writer, sheet_name='Average Iter Time (ms)')
+                self.iter_time_df.to_excel(writer, sheet_name='Median Iter Time (ms)')
                 self.tokens_per_sec_df.to_excel(writer, sheet_name='Tokens per sec')
                 self.tokens_per_sec_per_gpu_df.to_excel(writer, sheet_name='Tokens per sec per GPU')
                 self.memory_used_GB_df.to_excel(writer, sheet_name='Memory allocated GB')
         elif self.output_format == 'print':
-            print("\nAVERAGE ITERATION TIME (ms)")
+            print("\nMEDIAN ITERATION TIME (ms)")
             print(self.iter_time_df)
             print("\nTHROUGHPUT (tokens/s)")
             print(self.tokens_per_sec_df)
@@ -105,9 +105,9 @@ def run_benchmark(self, kwargs):
             self.perf_metrics_dict = json.load(file)
         os.remove(self.json_file_path) #cleanup after test finishes

-        if self.perf_metrics_dict['average_iter_time'] is None:
+        if self.perf_metrics_dict['median_iter_time'] is None:
             if 'CUDA out of memory' in proc_output.stdout:
-                self.perf_metrics_dict['average_iter_time'] = 'OOM'
+                self.perf_metrics_dict['median_iter_time'] = 'OOM'
                 self.perf_metrics_dict['model_flops'] = 'OOM'
                 self.perf_metrics_dict['model_flop_per_sec'] = 'OOM'
                 self.perf_metrics_dict['tokens_per_sec'] = 'OOM'
@@ -188,7 +188,7 @@ def tearDownClass(cls):
     shard_mode = ("zero2", ),
     model_name = ("Llama-2-7b-hf", ),
     micro_batch_size = (1, 4, ),
-    compile = ("eager", "inductor", "thunder", "thunder_inductor",)
+    compile = ("eager", "inductor", "thunder", "thunder_inductor", "thunder_inductor_transformerengine")
 )

 def test(self, **kwargs):
16 changes: 10 additions & 6 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -2,6 +2,7 @@
 import copy
 import time
 import pprint
+import numpy as np

 import torch
 import functools
@@ -282,8 +283,10 @@ def train(self):

         for i in range(self.max_iters):
             iter_t0 = time.perf_counter()
+            iter_times = []
             if i == self.warmup_iter: # warmup
                 t0 = iter_t0
+                iter_times = [] #reset the iter time list

             for step_idx in range(self.gradient_accumulation_steps):
                 input_ids, targets = next(self.train_data_iter)
@@ -324,9 +327,11 @@ def train(self):
             loss_item = loss.item() # synchronization
             t1 = time.perf_counter()
             if global_rank in [0, None]:
+                iter_time = (t1 - iter_t0) * 1000
                 print(
-                    f"iter {i}: loss {loss_item:.4f}, iter time: {(t1 - iter_t0) * 1000:.2f}ms, t: {input_ids.size(1)}"
+                    f"iter {i}: loss {loss_item:.4f}, iter time: {iter_time:.2f}ms, t: {input_ids.size(1)}"
                 )
+                iter_times.append(iter_time)

             # if global_rank in [0, None] and i >=warmup_iter:
             #     self.throughput.update(
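
Putting the two train() hunks together: the wall clock is read around the entire gradient-accumulation loop, loss.item() forces a device synchronization before t1, and timings gathered before the warmup boundary are discarded. A condensed, self-contained sketch of that measurement pattern (do_step is a hypothetical stand-in for the real forward/backward/optimizer work, and the list is initialized once before the loop):

    import time

    def do_step():
        # Stand-in for one full iteration of forward/backward/optimizer work.
        # In the real benchmark, loss.item() synchronizes the GPU before the
        # clock is read, so the measured span includes device-side work.
        time.sleep(0.01)

    warmup_iter, max_iters = 2, 10
    iter_times = []
    for i in range(max_iters):
        iter_t0 = time.perf_counter()
        if i == warmup_iter:
            iter_times = []  # discard warmup timings (compilation, cache warmup)
        do_step()
        t1 = time.perf_counter()
        iter_times.append((t1 - iter_t0) * 1000)  # per-iteration wall time, ms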
@@ -342,20 +347,19 @@
             #     print(metrics)

         if global_rank in [0, None]:
             # print(f"Total time: {(t1 - t0):.2f}s")
-            # print(f"Average time per iter: {((t1 - t0)*1000)/(max_iters-warmup_iter):.2f}ms")
-            self.perf_metrics["average_iter_time"] = ((t1 - t0) * 1000) / (self.max_iters - self.warmup_iter)
+            self.perf_metrics["median_iter_time"] = np.median(iter_times) #To avoid outliers

     def add_perf_metrics(self):
         # tokens_per_sec = total number of benchmarked iterations x global BS x block_size / total elapsed time (s)
         #                = global BS x block_size / (total elapsed time (s)/total number of benchmarked iterations)
         #                = global BS x block_size / average iter time (s)
         self.perf_metrics["tokens_per_sec"] = (
-            self.global_batch_size * self.model.max_seq_length * 1000 / self.perf_metrics["average_iter_time"]
+            self.global_batch_size * self.model.max_seq_length * 1000 / self.perf_metrics["median_iter_time"]
         ) # tokens/s
         if self.perf_metrics["model_flops"] is not None:
             self.perf_metrics["model_flop_per_sec"] = (
-                self.perf_metrics["model_flops"] * 1000 / self.perf_metrics["average_iter_time"]
+                self.perf_metrics["model_flops"] * 1000 / self.perf_metrics["median_iter_time"]
             )
         if world_size is not None:
             self.perf_metrics["model_flop_per_sec"] *= world_size
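
As a worked instance of the tokens/s formula in the comment above (illustrative numbers, not measured results): with global batch size 128, sequence length 4096, and a 2000 ms median iteration time, the benchmark processes 128 x 4096 = 524288 tokens every 2 s:

    global_batch_size, max_seq_length, median_iter_time_ms = 128, 4096, 2000.0

    # The x1000 converts the ms iteration time into seconds in the denominator
    tokens_per_sec = global_batch_size * max_seq_length * 1000 / median_iter_time_ms
    print(tokens_per_sec)  # 262144.0 tokens/s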
@@ -428,7 +432,7 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
             f"Sharding Mode: {benchmark.shard_mode}\nSharding Size: {benchmark.sharding_size}\nBucketing: {benchmark.bucketing_mode}"
         )
         print(f"Compiler: {benchmark.compile}")
-        print(f"Average iter time: {benchmark.perf_metrics['average_iter_time']:.2f} ms")
+        print(f"Median iter time: {benchmark.perf_metrics['median_iter_time']:.2f} ms")
         print(f"Memory used: {benchmark.perf_metrics['memory_used_GB']:.02f} GB")
         print(f"Throughput (Tokens/s): {benchmark.perf_metrics['tokens_per_sec']:.02f} tokens/s")
         print(
