From db3b37a6620f2d09e01ce9e6c05e7a17fce6ae3f Mon Sep 17 00:00:00 2001 From: Tien Tr Date: Thu, 4 Aug 2022 09:35:02 +0700 Subject: [PATCH 1/2] Update `progress_bar` `checkpoint_callback`, `weights_summary`, `progress_bar_refresh_rate=progress_bar_refresh_rate` are removed in pytorch_lightning v1.7 --- aitextgen/aitextgen.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 12a0a3d..c383153 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -709,10 +709,9 @@ def train( gpus=n_gpu, max_steps=num_steps, gradient_clip_val=max_grad_norm, - checkpoint_callback=False, + enable_checkpointing=False, #checkpoint_callback deprecated in pytorch_lighning v1.7 logger=loggers if loggers else False, - weights_summary=None, - progress_bar_refresh_rate=progress_bar_refresh_rate, # ignored + enable_model_summary=None, #weights_summary and progress_bar_refresh_rate are removed in pytorch_lighning v1.7 callbacks=[ ATGProgressBar( save_every, From 45115a7a27c31f7c95d01099637f79293538cc57 Mon Sep 17 00:00:00 2001 From: Tien Tr Date: Thu, 4 Aug 2022 09:39:41 +0700 Subject: [PATCH 2/2] Replace `progress_bar_dict` with `get_metrics` As shown here https://pytorch-lightning.readthedocs.io/en/stable/generated/CHANGELOG.html --- aitextgen/train.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/aitextgen/train.py b/aitextgen/train.py index 2e289cb..ce54a4a 100644 --- a/aitextgen/train.py +++ b/aitextgen/train.py @@ -141,6 +141,12 @@ def on_train_start(self, trainer, pl_module): def on_train_end(self, trainer, pl_module): self.main_progress_bar.close() self.unfreeze_layers(pl_module) + + def get_metrics(self, trainer, pl_module): + # don't show the version number + items = super().get_metrics(trainer, pl_module) + items.pop("v_num", None) + return items def on_batch_end(self, trainer, pl_module): super().on_batch_end(trainer, pl_module) @@ -150,7 +156,8 @@ def on_batch_end(self, trainer, pl_module): if self.steps == 0 and self.gpu: torch.cuda.empty_cache() - current_loss = float(trainer.progress_bar_dict["loss"]) + metrics = self.get_metrics(trainer, pl_module) + current_loss = float(metrics["loss"]) self.steps += 1 avg_loss = 0 if current_loss == current_loss: # don't add if current_loss is NaN