Skip to content

Commit

Permalink
Info about batches added
Browse files Browse the repository at this point in the history
  • Loading branch information
Demirrr committed Oct 28, 2024
1 parent 9834ad1 commit 628aa4b
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions dicee/trainer/torch_trainer_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def train(self):
-------
"""
+ num_of_batches=len(self.train_dataset_loader)
for epoch in (tqdm_bar := make_iterable_verbose(range(self.num_epochs),
verbose=self.local_rank == self.global_rank == 0,
position=0,
Expand All @@ -179,11 +180,11 @@ def train(self):
if hasattr(tqdm_bar, 'set_description_str'):
tqdm_bar.set_description_str(f"Epoch:{epoch + 1}")
if i > 0:
- tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}")
+ tqdm_bar.set_postfix_str(f"batch={i} | {num_of_batches}, loss_step={batch_loss:.5f}, loss_epoch={epoch_loss / i:.5f}")
else:
tqdm_bar.set_postfix_str(f"loss_step={batch_loss:.5f}, loss_epoch={batch_loss:.5f}")

- avg_epoch_loss = epoch_loss / len(self.train_dataset_loader)
+ avg_epoch_loss = epoch_loss / num_of_batches

if self.local_rank == self.global_rank == 0:
self.model.module.loss_history.append(avg_epoch_loss)
Expand Down

0 comments on commit 628aa4b

Please sign in to comment.