use a sync batchnorm as a way of normalizing all channels of the HRRR, applied to the target but not to the predictions
lucidrains committed Nov 7, 2023
1 parent c053ca8 commit c533fcc
Showing 3 changed files with 32 additions and 4 deletions.
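For orientation, the idea in the commit message can be sketched in a few lines of standalone PyTorch (a plain `nn.BatchNorm2d` stands in for the maybe-sync variant, and the channel count and tensor shapes are made up for illustration): the target passes through a non-affine batchnorm that is tracking running statistics, the prediction is normalized by the same layer while it is frozen, and the MSE is taken in that normalized space.

```python
import torch
from torch import nn
import torch.nn.functional as F

hrrr_channels = 32  # hypothetical channel count, for illustration only

# affine = False: no learnable scale/shift, the layer only tracks per-channel
# running mean and variance of whatever is fed through it in train mode
norm = nn.BatchNorm2d(hrrr_channels, affine = False)

hrrr_target = torch.randn(2, hrrr_channels, 32, 32)
hrrr_pred = torch.randn(2, hrrr_channels, 32, 32)

# target pass in train mode: normalized by the batch statistics, with the
# running mean / variance of the targets updated as a side effect
norm.train()
normed_target = norm(hrrr_target)

# prediction pass with the layer frozen: eval mode normalizes with the
# accumulated running statistics and leaves them untouched
norm.eval()
stats_before = norm.running_mean.clone()
normed_pred = norm(hrrr_pred)
assert torch.equal(norm.running_mean, stats_before)  # predictions did not alter the stats

loss = F.mse_loss(normed_pred, normed_target)
```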
README.md (3 changes: 2 additions & 1 deletion)
@@ -73,8 +73,9 @@ surface_target, hrrr_target, precipitation_target = metnet3(
## Todo

- [x] figure out all the cross entropy and MSE losses
- [x] auto-handle normalization across all the channels of the HRRR by tracking a running mean and variance of targets during training (using sync batchnorm as a hack)

- [ ] auto-handle normalization across all the channels of the HRRR by tracking a running mean and variance of targets during training, as well as allow the researcher to pass in their own normalization variables
- [ ] allow the researcher to pass in their own normalization variables for HRRR
- [ ] figure out the topological embedding, consult a neural weather researcher

## Citations
metnet3_pytorch/metnet3_pytorch.py (31 changes: 29 additions & 2 deletions)
@@ -1,3 +1,4 @@
from contextlib import contextmanager
from functools import partial
from collections import namedtuple

@@ -40,6 +41,21 @@ def MaybeSyncBatchnorm2d(is_distributed = None):
    is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1)
    return nn.SyncBatchNorm if is_distributed else nn.BatchNorm2d

@contextmanager
def freeze_batchnorm(bn):
    assert not exists(next(bn.parameters(), None))

    was_training = bn.training
    was_tracking_stats = bn.track_running_stats # in some versions of pytorch, the running mean and variance still get updated even in eval mode, it seems..

    bn.eval()
    bn.track_running_stats = False

    yield bn

    bn.train(was_training)
    bn.track_running_stats = was_tracking_stats

# loss scaling in section 4.3.2

class LossScaleFunction(Function):
@@ -686,6 +702,8 @@ def __init__(
            nn.Conv2d(dim, precipitation_target_channels, 1)
        )

        self.batchnorm_hrrr = MaybeSyncBatchnorm2d()(hrrr_target_channels, affine = False)

        self.mse_loss_scaler = LossScaler()

    def forward(
@@ -779,9 +797,18 @@ def forward(

        # calculate HRRR mse loss

        hrrr_pred = self.mse_loss_scaler(hrrr_pred)
        # use a batchnorm to normalize each channel to mean zero and unit variance

        normed_hrrr_target = self.batchnorm_hrrr(hrrr_target)

        with freeze_batchnorm(self.batchnorm_hrrr) as frozen_batchnorm:
            normed_hrrr_pred = frozen_batchnorm(hrrr_pred)

        # proposed loss gradient rescaler from section 4.3.2

        normed_hrrr_pred = self.mse_loss_scaler(normed_hrrr_pred)

        hrrr_loss = F.mse_loss(hrrr_pred, hrrr_target)
        hrrr_loss = F.mse_loss(normed_hrrr_pred, normed_hrrr_target)

        # total loss

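As a usage illustration (a sketch, not part of the commit itself), the `freeze_batchnorm` context manager added above can be exercised on its own, assuming it is imported from the module where this commit defines it; inside the block the layer runs in eval mode without updating its running statistics, and its previous mode is restored on exit.

```python
import torch
from torch import nn
from metnet3_pytorch.metnet3_pytorch import freeze_batchnorm

bn = nn.BatchNorm2d(16, affine = False)  # must be non-affine, the context manager asserts it has no parameters
bn.train()

_ = bn(torch.randn(4, 16, 8, 8))   # accumulate some running statistics from a "target"
stats = bn.running_mean.clone()

with freeze_batchnorm(bn) as frozen_bn:
    out = frozen_bn(torch.randn(4, 16, 8, 8))  # normalized with the frozen statistics

assert torch.equal(bn.running_mean, stats)  # nothing was updated inside the block
assert bn.training                          # train mode restored on exit
```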
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
  name = 'metnet3-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.2',
  version = '0.0.3',
  license='MIT',
  description = 'MetNet 3 - Pytorch',
  author = 'Phil Wang',

