From 9f84835e4dd553ba32692195061e48fb490a330d Mon Sep 17 00:00:00 2001
From: Lorenzo Breschi
Date: Tue, 3 Sep 2024 11:55:05 +0000
Subject: [PATCH] era5 training

---
 .gitignore    |   2 +
 train/era5.py | 185 ++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 187 insertions(+)
 create mode 100644 train/era5.py

diff --git a/.gitignore b/.gitignore
index d248bf98..e450bd8f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,5 @@
 # pixi environments
 .pixi
 .vscode/
+checkpoints/
+lightning_logs/
diff --git a/train/era5.py b/train/era5.py
new file mode 100644
index 00000000..343d5975
--- /dev/null
+++ b/train/era5.py
@@ -0,0 +1,185 @@
+import xarray
+import numpy as np
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning.callbacks import ModelCheckpoint
+from torch.utils.data import DataLoader, Dataset
+
+from graph_weather.models import MetaModel
+from graph_weather.models.losses import NormalizedMSELoss
+
+from einops import rearrange
+
+
+class LitGraphForecaster(pl.LightningModule):
+    """
+    LightningModule for graph-based weather forecasting.
+
+    Attributes:
+        model (MetaModel): Graph weather forecasting model.
+        criterion (NormalizedMSELoss): Loss criterion for training.
+        lr (float): Learning rate for the optimizer.
+
+    Methods:
+        __init__: Initialize the LitGraphForecaster object.
+        forward: Forward pass of the model.
+        training_step: Run a single training step.
+        configure_optimizers: Configure the optimizer for training.
+    """
+
+    def __init__(
+        self,
+        lat_lons: list,
+        *,
+        channels: int,
+        image_size,
+        patch_size=4,
+        depth=5,
+        heads=4,
+        mlp_dim=5,
+        feature_dim: int = 605,
+        lr: float = 3e-4,
+    ):
+        """
+        Initialize the LitGraphForecaster object with the required args.
+
+        Args:
+            lat_lons (list): List of (latitude, longitude) pairs for the grid nodes.
+            channels (int): Number of input channels (variables) per node.
+            image_size (tuple): Size of the internal processing grid as (height, width).
+            patch_size (int): Patch size used by the model.
+            depth (int): Number of model layers.
+            heads (int): Number of attention heads.
+            mlp_dim (int): Hidden dimension of the MLP layers.
+            feature_dim (int): Number of features, used to size the loss normalization.
+            lr (float): Learning rate for the optimizer.
+        """
+        super().__init__()
+        self.model = MetaModel(
+            lat_lons,
+            image_size=image_size,
+            patch_size=patch_size,
+            depth=depth,
+            heads=heads,
+            mlp_dim=mlp_dim,
+            channels=channels
+        )
+        self.criterion = NormalizedMSELoss(
+            lat_lons=lat_lons, feature_variance=np.ones((feature_dim,))
+        )
+        self.lr = lr
+        self.save_hyperparameters()
+
+    def forward(self, x):
+        """
+        Forward pass of the model.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor.
+        """
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        """
+        Run a single training step.
+
+        Args:
+            batch (torch.Tensor): Batch of stacked (input, target) tensors.
+            batch_idx (int): Index of the current batch.
+
+        Returns:
+            torch.Tensor: Loss tensor, or None to skip batches containing NaNs.
+        """
+        x, y = batch[:, 0], batch[:, 1]
+        if torch.isnan(x).any() or torch.isnan(y).any():
+            return None
+        y_hat = self.forward(x)
+        loss = self.criterion(y_hat, y)
+        self.log('loss', loss, prog_bar=True)
+        return loss
+
+    def configure_optimizers(self):
+        """
+        Configure the optimizer.
+
+        Returns:
+            torch.optim.Optimizer: Optimizer instance.
+ """ + return torch.optim.AdamW(self.parameters(), lr=self.lr) + + +class Era5Dataset(Dataset): + """Era5 dataset.""" + + def __init__(self, xarr, transform=None): + """ + Arguments: + #TODO + """ + ds = np.asarray(xarr.to_array()) + ds = torch.from_numpy(ds) + ds -= ds.min(0, keepdim=True)[0] + ds /= ds.max(0, keepdim=True)[0] + ds = rearrange(ds, "C T H W -> T (H W) C") + self.ds = ds + + def __len__(self): + return len(self.ds) - 1 + + def __getitem__(self, index): + return self.ds[index:index+2] + + +if __name__ == "__main__": + + patch_size = 4 + grid_step = 20 + + reanalysis = xarray.open_zarr( + 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3', + storage_options=dict(token='anon'), + + ) + reanalysis = reanalysis.isel(time=slice(100, 400), longitude=slice( + 0, 1440, grid_step), latitude=slice(0, 721, grid_step)) + print(f'size: {reanalysis.nbytes / (1024 ** 3)} GiB') + + lat_lons = np.array( + np.meshgrid( + np.asarray(reanalysis["latitude"]).flatten(), + np.asarray(reanalysis["longitude"]).flatten(), + ) + ).T.reshape((-1, 2)) + + checkpoint_callback = ModelCheckpoint( + dirpath="./checkpoints", save_top_k=1, monitor="loss") + reanalysis = reanalysis[["2m_temperature", + "surface_pressure", + "10m_u_component_of_wind", + "10m_v_component_of_wind"]] + + shape = np.asarray(reanalysis.to_array()).shape + channels = shape[0] + + dset = DataLoader(Era5Dataset(reanalysis), batch_size=10, num_workers=8) + model = LitGraphForecaster(lat_lons=lat_lons, + channels=channels, + image_size=(721//grid_step, 1440//grid_step), + patch_size=patch_size, + depth=5, + heads=4, + mlp_dim=5) + trainer = pl.Trainer( + accelerator="gpu", + devices=-1, + max_epochs=1000, + precision="16-mixed", + callbacks=[checkpoint_callback], + log_every_n_steps=3 + + ) + + trainer.fit(model, dset)