From 4596a67e492746bfaed9e0927079553bb35d809f Mon Sep 17 00:00:00 2001
From: parthmannan
Date: Wed, 13 Mar 2024 06:41:03 +0000
Subject: [PATCH] Using median iteration time to avoid outliers skewing data

---
 examples/lit-gpt/test_parametrized.py  | 12 ++++++------
 thunder/benchmarks/benchmark_litgpt.py | 16 ++++++++++------
 2 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/examples/lit-gpt/test_parametrized.py b/examples/lit-gpt/test_parametrized.py
index 20ddaa9278..a774957a25 100644
--- a/examples/lit-gpt/test_parametrized.py
+++ b/examples/lit-gpt/test_parametrized.py
@@ -57,7 +57,7 @@ def complete_dataframe(self, is_teardown):
         df['Sharding Size'] = df['Sharding Size'].fillna('none') #Convert None Type to string so that pivot table can group.
         index_list = ['model_name', 'Num GPUS', 'Seq Len', 'Micro BS', 'Global BS', 'GA', 'Distributed Mode', 'Sharding Size']
 
-        self.iter_time_df = df.pivot_table(index=index_list, columns='compiler', values='average_iter_time', aggfunc='first').reset_index()
+        self.iter_time_df = df.pivot_table(index=index_list, columns='compiler', values='median_iter_time', aggfunc='first').reset_index()
         self.tokens_per_sec_df = df.pivot_table(index=index_list, columns='compiler', values='tokens_per_sec', aggfunc='first').reset_index()
         self.tokens_per_sec_per_gpu_df = df.pivot_table(index=index_list, columns='compiler', values='tokens_per_sec_per_gpu', aggfunc='first').reset_index()
         self.memory_used_GB_df = df.pivot_table(index=index_list, columns='compiler', values='memory_used_GB', aggfunc='first').reset_index()
@@ -72,12 +72,12 @@ def complete_dataframe(self, is_teardown):
             filename = '/scratch/lightning-thunder/examples/lit-gpt/' + str(filename)
 
             with pd.ExcelWriter(filename, engine='xlsxwriter') as writer:
-                self.iter_time_df.to_excel(writer, sheet_name='Average Iter Time (ms)')
+                self.iter_time_df.to_excel(writer, sheet_name='Median Iter Time (ms)')
                 self.tokens_per_sec_df.to_excel(writer, sheet_name='Tokens per sec')
                 self.tokens_per_sec_per_gpu_df.to_excel(writer, sheet_name='Tokens per sec per GPU')
                 self.memory_used_GB_df.to_excel(writer, sheet_name='Memory allocated GB')
         elif self.output_format == 'print':
-            print("\nAVERAGE ITERATION TIME (ms)")
+            print("\nMEDIAN ITERATION TIME (ms)")
             print(self.iter_time_df)
             print("\nTHROUGHPUT (tokens/s)")
             print(self.tokens_per_sec_df)
@@ -105,9 +105,9 @@ def run_benchmark(self, kwargs):
             self.perf_metrics_dict = json.load(file)
         os.remove(self.json_file_path) #cleanup after test finishes
 
-        if self.perf_metrics_dict['average_iter_time'] is None:
+        if self.perf_metrics_dict['median_iter_time'] is None:
             if 'CUDA out of memory' in proc_output.stdout:
-                self.perf_metrics_dict['average_iter_time'] = 'OOM'
+                self.perf_metrics_dict['median_iter_time'] = 'OOM'
                 self.perf_metrics_dict['model_flops'] = 'OOM'
                 self.perf_metrics_dict['model_flop_per_sec'] = 'OOM'
                 self.perf_metrics_dict['tokens_per_sec'] = 'OOM'
@@ -188,7 +188,7 @@ def tearDownClass(cls):
         shard_mode = ("zero2", ),
         model_name = ("Llama-2-7b-hf", ),
         micro_batch_size = (1, 4, ),
-        compile = ("eager", "inductor", "thunder", "thunder_inductor",)
+        compile = ("eager", "inductor", "thunder", "thunder_inductor", "thunder_inductor_transformerengine")
     )
     def test(self, **kwargs):
diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index 3bad00892d..f68dd1f0fa 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -2,6 +2,7 @@
 import copy
 import time
 import pprint
+import numpy as np
 import torch
 import functools
@@ -282,8 +283,10 @@ def train(self):
 
+        iter_times = []  # per-iteration times in ms, used for the median
         for i in range(self.max_iters):
             iter_t0 = time.perf_counter()
             if i == self.warmup_iter:  # warmup
                 t0 = iter_t0
+                iter_times = []  # reset the iter time list so warmup iterations are excluded
 
             for step_idx in range(self.gradient_accumulation_steps):
                 input_ids, targets = next(self.train_data_iter)
@@ -324,9 +327,11 @@ def train(self):
                 loss_item = loss.item()  # synchronization
             t1 = time.perf_counter()
             if global_rank in [0, None]:
+                iter_time = (t1 - iter_t0) * 1000
                 print(
-                    f"iter {i}: loss {loss_item:.4f}, iter time: {(t1 - iter_t0) * 1000:.2f}ms, t: {input_ids.size(1)}"
+                    f"iter {i}: loss {loss_item:.4f}, iter time: {iter_time:.2f}ms, t: {input_ids.size(1)}"
                 )
+                iter_times.append(iter_time)
 
             # if global_rank in [0, None] and i >=warmup_iter:
             #     self.throughput.update(
@@ -342,20 +347,19 @@ def train(self):
             #     print(metrics)
 
         if global_rank in [0, None]:
-            # print(f"Total time: {(t1 - t0):.2f}s")
-            # print(f"Average time per iter: {((t1 - t0)*1000)/(max_iters-warmup_iter):.2f}ms")
             self.perf_metrics["average_iter_time"] = ((t1 - t0) * 1000) / (self.max_iters - self.warmup_iter)
+            self.perf_metrics["median_iter_time"] = np.median(iter_times)  # To avoid outliers skewing the data
 
     def add_perf_metrics(self):
         # tokens_per_sec = total number of benchmarked iterations x global BS x block_size / total elapsed time (s)
         #               = global BS x block_size / (total elapsed time (s)/total number of benchmarked iterations)
         #               = global BS x block_size / average iter time (s)
         self.perf_metrics["tokens_per_sec"] = (
-            self.global_batch_size * self.model.max_seq_length * 1000 / self.perf_metrics["average_iter_time"]
+            self.global_batch_size * self.model.max_seq_length * 1000 / self.perf_metrics["median_iter_time"]
         )  # tokens/s
         if self.perf_metrics["model_flops"] is not None:
             self.perf_metrics["model_flop_per_sec"] = (
-                self.perf_metrics["model_flops"] * 1000 / self.perf_metrics["average_iter_time"]
+                self.perf_metrics["model_flops"] * 1000 / self.perf_metrics["median_iter_time"]
             )
             if world_size is not None:
                 self.perf_metrics["model_flop_per_sec"] *= world_size
@@ -428,7 +432,7 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None
             f"Sharding Mode: {benchmark.shard_mode}\nSharding Size: {benchmark.sharding_size}\nBucketing: {benchmark.bucketing_mode}"
         )
         print(f"Compiler: {benchmark.compile}")
-        print(f"Average iter time: {benchmark.perf_metrics['average_iter_time']:.2f} ms")
+        print(f"Median iter time: {benchmark.perf_metrics['median_iter_time']:.2f} ms")
         print(f"Memory used: {benchmark.perf_metrics['memory_used_GB']:.02f} GB")
         print(f"Throughput (Tokens/s): {benchmark.perf_metrics['tokens_per_sec']:.02f} tokens/s")
         print(
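
For context on why the median is reported: a single slow iteration (for example, a one-off data-loading stall or CUDA sync) pulls the mean up noticeably but barely moves the median. A minimal standalone sketch, not part of the patch, using hypothetical per-iteration timings:

import numpy as np

iter_times_ms = [102.1, 101.8, 102.4, 101.9, 350.0]  # hypothetical timings with one outlier
print(f"mean:   {np.mean(iter_times_ms):.2f} ms")    # 151.64 ms, skewed by the outlier
print(f"median: {np.median(iter_times_ms):.2f} ms")  # 102.10 ms, unaffected by the outlier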