From c533fcc48295267bfffe77d5815c5567ec67f4e0 Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 7 Nov 2023 15:47:42 -0800
Subject: [PATCH] use a sync batchnorm as a way of normalizing all channels of
 the HRRR, applied to the target but not to the predictions

---
 README.md                          |  3 ++-
 metnet3_pytorch/metnet3_pytorch.py | 31 ++++++++++++++++++++++++++++--
 setup.py                           |  2 +-
 3 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index fbf0474..a2c07ff 100644
--- a/README.md
+++ b/README.md
@@ -73,8 +73,9 @@ surface_target, hrrr_target, precipitation_target = metnet3(
 ## Todo
 
 - [x] figure out all the cross entropy and MSE losses
+- [x] auto-handle normalization across all the channels of the HRRR by tracking a running mean and variance of targets during training (using sync batchnorm as hack)
 
-- [ ] auto-handle normalization across all the channels of the HRRR by tracking a running mean and variance of targets during training, as well as allow researcher to pass in their own normalization variables
+- [ ] allow researcher to pass in their own normalization variables for HRRR
 - [ ] figure out the topological embedding, consult a neural weather researcher
 
 ## Citations
diff --git a/metnet3_pytorch/metnet3_pytorch.py b/metnet3_pytorch/metnet3_pytorch.py
index a3a099a..7477209 100644
--- a/metnet3_pytorch/metnet3_pytorch.py
+++ b/metnet3_pytorch/metnet3_pytorch.py
@@ -1,3 +1,4 @@
+from contextlib import contextmanager
 from functools import partial
 from collections import namedtuple
 
@@ -40,6 +41,21 @@ def MaybeSyncBatchnorm2d(is_distributed = None):
     is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1)
     return nn.SyncBatchNorm if is_distributed else nn.BatchNorm2d
 
+@contextmanager
+def freeze_batchnorm(bn):
+    assert not exists(next(bn.parameters(), None))
+
+    was_training = bn.training
+    was_tracking_stats = bn.track_running_stats # in some versions of pytorch, running mean and variance still gets updated even in eval mode it seems..
+
+    bn.eval()
+    bn.track_running_stats = False
+
+    yield bn
+
+    bn.train(was_training)
+    bn.track_running_stats = was_tracking_stats
+
 # loss scaling in section 4.3.2
 
 class LossScaleFunction(Function):
@@ -686,6 +702,8 @@ def __init__(
             nn.Conv2d(dim, precipitation_target_channels, 1)
         )
 
+        self.batchnorm_hrrr = MaybeSyncBatchnorm2d()(hrrr_target_channels, affine = False)
+
         self.mse_loss_scaler = LossScaler()
 
     def forward(
@@ -779,9 +797,18 @@ def forward(
 
         # calculate HRRR mse loss
 
-        hrrr_pred = self.mse_loss_scaler(hrrr_pred)
+        # use a batchnorm to normalize each channel to mean zero and unit variance
+
+        normed_hrrr_target = self.batchnorm_hrrr(hrrr_target)
+
+        with freeze_batchnorm(self.batchnorm_hrrr) as frozen_batchnorm:
+            normed_hrrr_pred = frozen_batchnorm(hrrr_pred)
+
+        # proposed loss gradient rescaler from section 4.3.2
+
+        normed_hrrr_pred = self.mse_loss_scaler(normed_hrrr_pred)
 
-        hrrr_loss = F.mse_loss(hrrr_pred, hrrr_target)
+        hrrr_loss = F.mse_loss(normed_hrrr_pred, normed_hrrr_target)
 
         # total loss
 
diff --git a/setup.py b/setup.py
index ccfd345..6f52360 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'metnet3-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.2',
+  version = '0.0.3',
   license='MIT',
   description = 'MetNet 3 - Pytorch',
   author = 'Phil Wang',
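
For readers of the hunk above: MaybeSyncBatchnorm2d is a small factory that returns nn.SyncBatchNorm when a multi-process group is live and plain nn.BatchNorm2d otherwise, so the running HRRR statistics stay consistent across data-parallel ranks. A minimal self-contained sketch of the pattern, with the repo's exists/default helpers reproduced under the assumption they behave as below, and an illustrative channel count:

    import torch.distributed as dist
    from torch import nn

    def exists(v):
        return v is not None

    def default(v, d):
        return v if exists(v) else d

    def MaybeSyncBatchnorm2d(is_distributed = None):
        # fall back to plain BatchNorm2d unless torch.distributed is
        # initialized with a world size greater than 1
        is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1)
        return nn.SyncBatchNorm if is_distributed else nn.BatchNorm2d

    # affine = False leaves the layer with no learnable parameters -
    # it acts purely as a per-channel running mean / variance tracker
    batchnorm_hrrr = MaybeSyncBatchnorm2d()(617, affine = False)  # 617 is an illustrative channel count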
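
The freeze_batchnorm context manager added in the patch restores the layer's mode after the yield. One possible hardening, not part of this commit: wrapping the restore in try/finally so an exception raised inside the with-block cannot leave the batchnorm permanently frozen. A sketch under that assumption:

    from contextlib import contextmanager

    def exists(v):
        return v is not None

    @contextmanager
    def freeze_batchnorm(bn):
        # expects a stat-tracking batchnorm with affine = False, i.e. no parameters
        assert not exists(next(bn.parameters(), None))

        was_training = bn.training
        was_tracking_stats = bn.track_running_stats

        # eval mode + disabled stat tracking: normalize with the stored
        # running statistics without updating them
        bn.eval()
        bn.track_running_stats = False

        try:
            yield bn
        finally:
            # restore the previous state even if the body raises
            bn.train(was_training)
            bn.track_running_stats = was_tracking_stats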
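
Putting the pieces together, a minimal sketch of the normalized HRRR loss this patch computes, reusing the freeze_batchnorm helper sketched above; the section 4.3.2 loss-gradient rescaler is elided, and the shapes and channel count are illustrative. The target pass updates the batchnorm's running mean and variance, while the prediction is normalized with those statistics frozen:

    import torch
    import torch.nn.functional as F
    from torch import nn

    channels = 617  # illustrative HRRR channel count

    batchnorm_hrrr = nn.BatchNorm2d(channels, affine = False)

    hrrr_pred = torch.randn(2, channels, 32, 32, requires_grad = True)
    hrrr_target = torch.randn(2, channels, 32, 32)

    # normalizing the target in train mode updates the running statistics,
    # so the normalizer tracks the data distribution over training
    normed_hrrr_target = batchnorm_hrrr(hrrr_target)

    # the prediction is normalized against the frozen running statistics
    with freeze_batchnorm(batchnorm_hrrr) as frozen_batchnorm:
        normed_hrrr_pred = frozen_batchnorm(hrrr_pred)

    hrrr_loss = F.mse_loss(normed_hrrr_pred, normed_hrrr_target)
    hrrr_loss.backward()

Because the prediction pass never touches the running statistics, the network cannot lower the loss by dragging the normalizer toward its own outputs, and the affine-free batchnorm adds no trainable parameters of its own.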