Skip to content

Commit

Permalink
era5 training
Browse files Browse the repository at this point in the history
  • Loading branch information
rnwzd committed Sep 3, 2024
1 parent 2970757 commit 9f84835
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@
# pixi environments
.pixi
.vscode/
checkpoints/
lightning_logs/
185 changes: 185 additions & 0 deletions train/era5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import click
import xarray
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset

from graph_weather.models import MetaModel
from graph_weather.models.losses import NormalizedMSELoss

from einops import rearrange


class LitGraphForecaster(pl.LightningModule):
    """
    Lightning wrapper that trains a ``MetaModel`` with ``NormalizedMSELoss``.

    Attributes:
        model (MetaModel): Network invoked by ``forward``.
        criterion (NormalizedMSELoss): Loss applied to (prediction, target).
        lr (float): Learning rate handed to the AdamW optimizer.

    Methods:
        __init__: Build the model and the loss.
        forward: Run the wrapped model on an input tensor.
        training_step: Compute and log the loss for one batch.
        configure_optimizers: Create the AdamW optimizer.
    """

    def __init__(
        self,
        lat_lons: list,
        *,
        channels: int,
        image_size,
        patch_size=4,
        depth=5,
        heads=4,
        mlp_dim=5,
        feature_dim: int = 605,
        lr: float = 3e-4,
    ):
        """
        Build the wrapped model and its loss.

        Args:
            lat_lons: Latitude/longitude pairs; passed to both ``MetaModel``
                and ``NormalizedMSELoss``.
            channels: Forwarded to ``MetaModel`` as ``channels``.
            image_size: Forwarded to ``MetaModel`` as ``image_size``.
            patch_size: Forwarded to ``MetaModel`` as ``patch_size``.
            depth: Forwarded to ``MetaModel`` as ``depth``.
            heads: Forwarded to ``MetaModel`` as ``heads``.
            mlp_dim: Forwarded to ``MetaModel`` as ``mlp_dim``.
            feature_dim: Length of the all-ones ``feature_variance`` vector
                given to the loss.
            lr (float): Learning rate for the optimizer.
        """
        super().__init__()

        model_options = dict(
            channels=channels,
            image_size=image_size,
            patch_size=patch_size,
            depth=depth,
            heads=heads,
            mlp_dim=mlp_dim,
        )
        self.model = MetaModel(lat_lons, **model_options)

        # The loss is given a unit variance for every feature.
        unit_variance = np.ones((feature_dim,))
        self.criterion = NormalizedMSELoss(
            lat_lons=lat_lons,
            feature_variance=unit_variance,
        )

        self.lr = lr
        self.save_hyperparameters()

    def forward(self, x):
        """
        Run the wrapped model.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            Whatever ``self.model`` returns for ``x``.
        """
        prediction = self.model(x)
        return prediction

    def training_step(self, batch, batch_idx):
        """
        Compute the loss for one batch and log it under the key ``loss``.

        Args:
            batch: Tensor indexed as ``batch[:, 0]`` (model input) and
                ``batch[:, 1]`` (target).
            batch_idx (int): Index of the current batch (unused).

        Returns:
            The loss tensor, or ``None`` when the input or the target
            contains a NaN.
        """
        inputs = batch[:, 0]
        targets = batch[:, 1]

        # Bail out before touching the model if either half carries a NaN.
        # NOTE(review): Lightning documents a ``None`` return as "skip this
        # batch" but not for multi-device strategies -- confirm this is OK
        # for the trainer configuration in use.
        for tensor in (inputs, targets):
            if torch.isnan(tensor).any():
                return None

        prediction = self.forward(inputs)
        loss = self.criterion(prediction, targets)
        self.log("loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """
        Create the optimizer.

        Returns:
            torch.optim.Optimizer: AdamW over all parameters of this module,
            using ``self.lr`` as the learning rate.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return optimizer


class Era5Dataset(Dataset):
    """
    Map-style dataset of consecutive time-step pairs.

    The source array is read as ``(C, T, H, W)`` (the layout produced by
    ``xarr.to_array()``), every channel is min-max scaled to ``[0, 1]`` on
    its own, and the result is stored as ``self.ds`` with layout
    ``(T, H * W, C)``.

    Item ``i`` is ``self.ds[i:i + 2]``: time step ``i`` followed by time
    step ``i + 1``.
    """

    def __init__(self, xarr, transform=None):
        """
        Load, normalise and reshape the data.

        Args:
            xarr: Object exposing ``to_array()`` whose result converts via
                ``np.asarray`` to a 4-D ``(C, T, H, W)`` array.
            transform: Optional callable applied to every pair returned by
                ``__getitem__``. ``None`` (default) returns the pair as is.

        Raises:
            ValueError: If the converted array is not 4-dimensional.
        """
        values = np.asarray(xarr.to_array())
        ds = torch.from_numpy(values)
        if ds.ndim != 4:
            raise ValueError(
                f"expected a 4-D (C, T, H, W) array, got shape {tuple(ds.shape)}"
            )
        if not ds.is_floating_point():
            ds = ds.float()

        if ds.numel() > 0:
            # Scale each channel separately.  Reducing over dim 0 (the
            # channel axis) would mix unrelated variables at every grid
            # point, so the statistics are taken over (T, H, W) instead.
            per_channel = (1, 2, 3)
            # NaNs are masked out of the statistics so that one missing
            # value does not turn an entire channel into NaN.
            missing = torch.isnan(ds)
            lo = ds.masked_fill(missing, float("inf")).amin(
                dim=per_channel, keepdim=True
            )
            hi = ds.masked_fill(missing, float("-inf")).amax(
                dim=per_channel, keepdim=True
            )
            span = hi - lo
            # A constant (or all-NaN) channel has no positive range; divide
            # by one instead of zero so no new NaN/inf values are created.
            span = torch.where(span > 0, span, torch.ones_like(span))
            # Out-of-place arithmetic: ``from_numpy`` shares memory with
            # ``values`` and the caller's data must not be overwritten.
            ds = (ds - lo) / span

        # Same layout change as einops "C T H W -> T (H W) C".
        self.ds = ds.permute(1, 2, 3, 0).flatten(1, 2)
        self.transform = transform

    def __len__(self):
        """Return the number of consecutive pairs (time steps minus one)."""
        return max(len(self.ds) - 1, 0)

    def __getitem__(self, index):
        """
        Return time steps ``index`` and ``index + 1`` stacked on dim 0.

        Args:
            index: Pair index; negative values count from the end.

        Returns:
            Tensor of shape ``(2, H * W, C)``, passed through ``transform``
            when one was given.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
                Raising here is what lets plain iteration terminate.
        """
        n_pairs = len(self)
        position = index + n_pairs if index < 0 else index
        if not 0 <= position < n_pairs:
            raise IndexError(
                f"index {index} is out of range for {n_pairs} pairs"
            )
        pair = self.ds[position:position + 2]
        if self.transform is not None:
            pair = self.transform(pair)
        return pair


if __name__ == "__main__":
    # Training script: open the ARCO ERA5 Zarr store, take a coarse
    # spatial/temporal subset of four surface variables, and fit
    # LitGraphForecaster on consecutive time-step pairs.

    patch_size = 4
    grid_step = 20  # keep every grid_step-th latitude/longitude point

    reanalysis = xarray.open_zarr(
        'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3',
        storage_options=dict(token='anon'),
    )
    reanalysis = reanalysis.isel(
        time=slice(100, 400),
        longitude=slice(0, 1440, grid_step),
        latitude=slice(0, 721, grid_step),
    )
    # Reported before the variable selection below, so this is the size of
    # the subset across every variable in the store.
    print(f'size: {reanalysis.nbytes / (1024 ** 3)} GiB')

    # One (lat, lon) row per grid point, latitude-major.
    lat_lons = np.array(
        np.meshgrid(
            np.asarray(reanalysis["latitude"]).flatten(),
            np.asarray(reanalysis["longitude"]).flatten(),
        )
    ).T.reshape((-1, 2))

    checkpoint_callback = ModelCheckpoint(
        dirpath="./checkpoints", save_top_k=1, monitor="loss")

    reanalysis = reanalysis[["2m_temperature",
                             "surface_pressure",
                             "10m_u_component_of_wind",
                             "10m_v_component_of_wind"]]

    # ``to_array()`` stacks the data variables along a new leading axis, so
    # the channel count is simply the number of data variables.  Reading it
    # from the metadata avoids materialising the whole remote selection just
    # to inspect its shape (Era5Dataset loads the data once, below).
    channels = len(reanalysis.data_vars)

    dset = DataLoader(Era5Dataset(reanalysis), batch_size=10, num_workers=8)

    # NOTE(review): slice(0, 721, grid_step) yields ceil(721 / grid_step)
    # latitude points (37 for grid_step=20) while image_size uses floor
    # division (36).  Presumably intentional so the height is divisible by
    # patch_size -- confirm against MetaModel.
    model = LitGraphForecaster(lat_lons=lat_lons,
                               channels=channels,
                               image_size=(721//grid_step, 1440//grid_step),
                               patch_size=patch_size,
                               depth=5,
                               heads=4,
                               mlp_dim=5)
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=-1,
        max_epochs=1000,
        precision="16-mixed",
        callbacks=[checkpoint_callback],
        log_every_n_steps=3,
    )

    trainer.fit(model, dset)

0 comments on commit 9f84835

Please sign in to comment.