Skip to content

Commit

Permalink
Merge pull request #191 from tientr/master
Browse files Browse the repository at this point in the history
Fix #189
  • Loading branch information
minimaxir authored Aug 9, 2022
2 parents 29d8dfa + 45115a7 commit 820d7da
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
5 changes: 2 additions & 3 deletions aitextgen/aitextgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,10 +709,9 @@ def train(
gpus=n_gpu,
max_steps=num_steps,
gradient_clip_val=max_grad_norm,
checkpoint_callback=False,
enable_checkpointing=False, #checkpoint_callback deprecated in pytorch_lightning v1.7
logger=loggers if loggers else False,
weights_summary=None,
progress_bar_refresh_rate=progress_bar_refresh_rate, # ignored
enable_model_summary=None, #weights_summary and progress_bar_refresh_rate are removed in pytorch_lightning v1.7
callbacks=[
ATGProgressBar(
save_every,
Expand Down
9 changes: 8 additions & 1 deletion aitextgen/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ def on_train_start(self, trainer, pl_module):
def on_train_end(self, trainer, pl_module):
    """Tear down at the end of training.

    Closes the tqdm progress bar opened for the run and re-enables
    gradients on any layers that were frozen for training.
    """
    # Close first so the bar's final state is flushed before any
    # layer bookkeeping happens.
    self.main_progress_bar.close()
    self.unfreeze_layers(pl_module)

def get_metrics(self, trainer, pl_module):
    """Return the progress-bar metrics without the version number.

    Delegates to the parent progress bar and drops the "v_num" entry
    (pytorch-lightning's experiment version) so it is not displayed.
    """
    metrics = super().get_metrics(trainer, pl_module)
    metrics.pop("v_num", None)  # absent key is fine — pop with default
    return metrics

def on_batch_end(self, trainer, pl_module):
super().on_batch_end(trainer, pl_module)
Expand All @@ -150,7 +156,8 @@ def on_batch_end(self, trainer, pl_module):
if self.steps == 0 and self.gpu:
torch.cuda.empty_cache()

current_loss = float(trainer.progress_bar_dict["loss"])
metrics = self.get_metrics(trainer, pl_module)
current_loss = float(metrics["loss"])
self.steps += 1
avg_loss = 0
if current_loss == current_loss: # don't add if current_loss is NaN
Expand Down

0 comments on commit 820d7da

Please sign in to comment.