Merge pull request #43 from krasserm/wip-pytorch-2
Upgrade to PyTorch 2.0 and PyTorch Lightning 2.0
krasserm authored Apr 6, 2023
2 parents 9f49d0b + 15f0e4d commit 737a766
Showing 5 changed files with 1,051 additions and 836 deletions.
4 changes: 2 additions & 2 deletions environment.yml
@@ -5,6 +5,6 @@ channels:
 dependencies:
 - python=3.9
 - pytorch-cuda=11.7
-- pytorch=1.13
-- torchvision=0.14
+- pytorch=2.0
+- torchvision=0.15
 - pip>=22
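After recreating or updating the conda environment from this file, a quick check can confirm that the pinned releases resolved correctly. The snippet below is a minimal sketch, not part of this repository; it only assumes the standard torch and torchvision version attributes.

# Minimal sanity check (sketch, not part of the repository) for the upgraded environment.
import torch
import torchvision

assert torch.__version__.startswith("2.0"), torch.__version__
assert torchvision.__version__.startswith("0.15"), torchvision.__version__

# pytorch-cuda=11.7 should give a CUDA 11.7 build; availability depends on the local driver.
print("torch:", torch.__version__, "cuda build:", torch.version.cuda, "cuda available:", torch.cuda.is_available())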
1 change: 0 additions & 1 deletion examples/training/clm/train_fsdp.sh
@@ -30,7 +30,6 @@ python -m perceiver.scripts.text.clm_fsdp fit \
   --trainer.precision=bf16 \
   --trainer.max_steps=50000 \
   --trainer.accumulate_grad_batches=1 \
-  --trainer.track_grad_norm=2 \
   --trainer.check_val_every_n_epoch=null \
   --trainer.val_check_interval=500 \
   --trainer.limit_val_batches=20 \
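The --trainer.track_grad_norm=2 flag is dropped because PyTorch Lightning 2.0 removed the track_grad_norm Trainer argument; gradient-norm logging now happens inside the LightningModule, as the clm_fsdp.py change below shows. A minimal, self-contained sketch of that 2.0 idiom (generic model, not the repository's Perceiver AR module):

# Sketch of the Lightning 2.0 replacement for Trainer(track_grad_norm=2):
# log gradient norms from on_before_optimizer_step instead of a Trainer flag.
import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities import grad_norm


class LitRegressor(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def on_before_optimizer_step(self, optimizer):
        # grad_norm() returns a dict of per-parameter and total 2-norms,
        # which log_dict() forwards to the configured logger.
        self.log_dict(grad_norm(self, norm_type=2))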
8 changes: 5 additions & 3 deletions perceiver/scripts/text/clm_fsdp.py
@@ -3,7 +3,8 @@

 import torch
 from pytorch_lightning.cli import LightningArgumentParser, LRSchedulerCallable, OptimizerCallable
-from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy, StrategyRegistry
+from pytorch_lightning.strategies import FSDPStrategy, StrategyRegistry
+from pytorch_lightning.utilities import grad_norm
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

 from perceiver.model.core import CrossAttentionLayer, SelfAttentionLayer
@@ -27,7 +28,7 @@

 StrategyRegistry.register(
     name="fsdp_perceiver_ar",
-    strategy=DDPFullyShardedNativeStrategy,
+    strategy=FSDPStrategy,
     description="FSDP strategy optimized for Perceiver AR models",
     activation_checkpointing=[CrossAttentionLayer, SelfAttentionLayer],
     auto_wrap_policy=policy,
@@ -60,9 +61,10 @@ def configure_optimizers(self):
             "lr_scheduler": {"scheduler": scheduler, "interval": "step", "frequency": 1},
         }

-    def on_before_optimizer_step(self, optimizer, optimizer_idx):
+    def on_before_optimizer_step(self, optimizer):
         if self.hparams.max_grad_norm is not None:
             self.trainer.model.clip_grad_norm_(self.hparams.max_grad_norm)
+        self.log_dict(grad_norm(self, norm_type=2))


 class CausalLanguageModelCLI(CLI):
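In Lightning 2.0 the native-FSDP strategy class is named FSDPStrategy (DDPFullyShardedNativeStrategy was removed), which is the only change needed in the registration above, and the optimizer_idx argument was dropped from the on_before_optimizer_step hook. For illustration, a self-contained sketch of the same registration pattern with generic placeholder names, not the repository's code:

# Hedged sketch of registering a custom FSDP strategy with Lightning 2.0.
# Block and "fsdp_example" are placeholders; the repository uses its own
# CrossAttentionLayer/SelfAttentionLayer classes and the name "fsdp_perceiver_ar".
import functools

import torch
from pytorch_lightning.strategies import FSDPStrategy, StrategyRegistry
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class Block(torch.nn.Module):
    """Stand-in for a transformer layer class that should become one FSDP unit."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


# Wrap every Block instance in its own FSDP unit.
policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})

StrategyRegistry.register(
    name="fsdp_example",
    strategy=FSDPStrategy,
    description="Example FSDP strategy",
    activation_checkpointing=[Block],  # checkpoint activations of each Block
    auto_wrap_policy=policy,
)

# The registered name can then be selected on the Trainer, e.g.
# Trainer(strategy="fsdp_example", accelerator="gpu", devices=2, precision="bf16"),
# or via LightningCLI with --trainer.strategy=fsdp_example.

The train_fsdp.sh script above presumably selects the registered strategy via --trainer.strategy=fsdp_perceiver_ar; that flag lies outside the visible diff, so this is an assumption.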
(Diffs for the remaining 2 changed files are not rendered here.)
