Commit 4b7e20b

Merge pull request #66 from epfLLM/tokens_per_second

Tokens per second metric

2 parents: d7e3d04 + 5b4ae47

File tree: 5 files changed, +42 −5 lines

finetune.py

Lines changed: 16 additions & 2 deletions

@@ -5,9 +5,10 @@

 import torch

-from megatron import get_args, get_tokenizer, get_timers, print_rank_0
+from megatron import get_args, get_tokenizer, get_timers, get_counters, print_rank_0
 from megatron.training import pretrain
 from megatron.core import tensor_parallel
+from megatron.core.parallel_state import get_data_parallel_group
 from megatron.model import GPTModel, ModelType, LlamaModel, FalconModel
 from megatron.utils import get_ltor_masks_and_position_ids, average_losses_across_data_parallel_group
 from megatron.data.gpt_dataset import build_train_valid_test_datasets as gpt_build_datasets

@@ -119,8 +120,21 @@ def get_batch(data_iterator):
     tokens = data_b["text"]
     labels = tokens[:, 1:].contiguous()
     tokens = tokens[:, :-1].contiguous()
-    if args.data_type == "gpt":

+    # Update tokens counter.
+    counters = get_counters()
+    n_tokens = torch.tensor(tokens.numel(), device=tokens.device)
+    if args.data_parallel_size == 1:
+        n_tokens = n_tokens.item()
+    else:
+        group = get_data_parallel_group()
+        torch.distributed.all_reduce(
+            n_tokens, op=torch.distributed.ReduceOp.SUM, group=group
+        )
+        n_tokens = n_tokens.item()
+    counters["tokens"] += n_tokens
+
+    if args.data_type == "gpt":
         # Get the masks and position ids.
         attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
             tokens,
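
Context: the counter update above sums the local token count across the data-parallel group with an all_reduce, so every rank accumulates the global number of tokens processed per step. A minimal standalone sketch of that pattern, assuming torch.distributed is already initialized whenever a group is passed (count_batch_tokens is a hypothetical helper, not part of the commit):

import torch
import torch.distributed as dist

def count_batch_tokens(tokens: torch.Tensor, group=None) -> int:
    """Sum the number of tokens in this batch over all ranks of `group`.

    Hypothetical helper mirroring the logic added to get_batch(); with
    group=None (the data_parallel_size == 1 case) the local count is used.
    """
    n_tokens = torch.tensor(tokens.numel(), device=tokens.device)
    if group is not None:
        # Each data-parallel rank loads a different micro-batch, so sum the
        # per-rank counts to obtain the global token count for this step.
        dist.all_reduce(n_tokens, op=dist.ReduceOp.SUM, group=group)
    return int(n_tokens.item())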

megatron/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@
 from .global_vars import get_tensorboard_writer
 from .global_vars import get_adlr_autoresume
 from .global_vars import get_timers
+from .global_vars import get_counters

 from .utils import (print_rank_0,
                     print_all_nodes,

megatron/global_vars.py

Lines changed: 15 additions & 0 deletions

@@ -4,6 +4,7 @@

 import os
 import sys
+from collections import defaultdict

 from megatron import dist_signal_handler
 from megatron.tokenizer import build_tokenizer

@@ -17,6 +18,7 @@
 _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None
 _GLOBAL_SIGNAL_HANDLER = None
+_GLOBAL_COUNTERS = None


 def get_args():

@@ -62,6 +64,12 @@ def get_timers():
     return _GLOBAL_TIMERS


+def get_counters():
+    """Return counters."""
+    _ensure_var_is_initialized(_GLOBAL_COUNTERS, 'counters')
+    return _GLOBAL_COUNTERS
+
+
 def get_signal_handler():
     _ensure_var_is_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
     return _GLOBAL_SIGNAL_HANDLER

@@ -90,6 +98,7 @@ def set_global_variables(args):
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
     _set_timers(args)
+    _set_counters(args)

     if args.exit_signal_handler:
         _set_signal_handler()

@@ -178,6 +187,12 @@ def _set_timers(args):
     _GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option)


+def _set_counters(args):
+    global _GLOBAL_COUNTERS
+    _ensure_var_is_not_initialized(_GLOBAL_COUNTERS, 'counters')
+    _GLOBAL_COUNTERS = defaultdict(int)
+
+
 def _ensure_var_is_initialized(var, name):
     """Make sure the input variable is not None."""
     assert var is not None, '{} is not initialized.'.format(name)
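
Context: the counters added here follow the same module-level singleton pattern as Megatron's other globals, using defaultdict(int) so any counter name starts at zero. A stripped-down sketch of that pattern on its own (hypothetical module, not the real global_vars.py):

from collections import defaultdict

_GLOBAL_COUNTERS = None

def _set_counters():
    # Initialize exactly once, mirroring _ensure_var_is_not_initialized.
    global _GLOBAL_COUNTERS
    assert _GLOBAL_COUNTERS is None, 'counters is already initialized.'
    _GLOBAL_COUNTERS = defaultdict(int)

def get_counters():
    # Mirrors _ensure_var_is_initialized: fail loudly if setup was skipped.
    assert _GLOBAL_COUNTERS is not None, 'counters is not initialized.'
    return _GLOBAL_COUNTERS

# Usage: call _set_counters() once at startup, then from anywhere:
#   get_counters()['tokens'] += 4096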

megatron/timers.py

Lines changed: 0 additions & 3 deletions

@@ -302,6 +302,3 @@ def write(self, names, writer, iteration, normalizer=1.0,
             for name in name_to_min_max_time:
                 _, max_time = name_to_min_max_time[name]
                 writer.add_scalar(name + '-time', max_time, iteration)
-        # if using wandb writer, flush the stats we just filled here, close to the creation time
-        if hasattr(writer,"flush_all"):
-            writer.flush_all()

megatron/training.py

Lines changed: 10 additions & 0 deletions

@@ -17,6 +17,7 @@
 from megatron import get_args
 from megatron import get_signal_handler
 from megatron import get_timers
+from megatron import get_counters
 from megatron import get_tensorboard_writer
 from megatron import get_current_global_batch_size
 from megatron import get_num_microbatches

@@ -590,17 +591,22 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     if iteration % args.log_interval == 0:
         elapsed_time = timers('interval-time').elapsed(barrier=True)
         elapsed_time_per_iteration = elapsed_time / total_iterations
+        counters = get_counters()
+        tokens = counters.pop('tokens') # reset counter for future iterations
+        tokens_per_sec = tokens/(elapsed_time)
         if writer:
             if args.log_timers_to_tensorboard:
                 writer.add_scalar('iteration-time',
                                   elapsed_time_per_iteration, iteration)
+                writer.add_scalar('tokens-per-sec', tokens_per_sec, iteration)

         log_string = ' iteration {:8d}/{:8d} |'.format(
             iteration, args.train_iters)
         log_string += ' consumed samples: {:12d} |'.format(
             args.consumed_train_samples)
         log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
             elapsed_time_per_iteration * 1000.0)
+        log_string += f' rate (tokens/sec): {tokens_per_sec:.2f} |'
         log_string += ' learning rate: {:.3E} |'.format(learning_rate)
         log_string += ' global batch size: {:5d} |'.format(batch_size)
         for key in total_loss_dict:

@@ -668,6 +674,7 @@ def _train(args, forward_step_func,
     # Iterations.
     iteration = args.iteration

+    counters = get_counters()
     timers('interval-time', log_level=0).start(barrier=True)
     print_datetime('before the start of training step')
     report_memory_flag = True

@@ -706,10 +713,13 @@ def _train(args, forward_step_func,
         if args.eval_interval and iteration % args.eval_interval == 0 and \
            args.do_valid:
             prefix = 'iteration {}'.format(iteration)
+            current_tokens = counters['tokens']
             evaluate_and_print_results(prefix, forward_step_func,
                                        valid_data_iterator, model,
                                        iteration, process_non_loss_data_func,
                                        verbose=False, args=args)
+            counters['tokens'] = current_tokens
+

         # if using wandb writer, flush the stats of train_step & potentially evaluate
         writer = get_tensorboard_writer()
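
Context: training_log divides the accumulated token counter by the wall-clock time of the logging interval and pops the counter so the next interval starts from zero, while _train snapshots and restores counters['tokens'] around evaluation so validation batches flowing through get_batch do not inflate the reported training throughput. A minimal sketch of the throughput calculation under those assumptions (plain time.time() and print stand in for Megatron's timers and tensorboard writer):

import time
from collections import defaultdict

counters = defaultdict(int)       # stand-in for get_counters()
_interval_start = time.time()     # stand-in for the 'interval-time' timer

def log_throughput(iteration: int) -> None:
    """Print tokens/sec over the last logging interval and reset the counter."""
    global _interval_start
    elapsed = time.time() - _interval_start
    tokens = counters.pop('tokens', 0)   # reset counter for future iterations
    tokens_per_sec = tokens / elapsed if elapsed > 0 else 0.0
    print(f' iteration {iteration:8d} | rate (tokens/sec): {tokens_per_sec:.2f} |')
    _interval_start = time.time()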
