Skip to content

Commit

Permalink
era5 training bugfix
Browse files (browse the repository at this point in the history)
  • Loading branch information
rnwzd committed Sep 9, 2024
1 parent 9f84835 commit b9c1e30
Showing 1 changed file with 29 additions and 21 deletions.
50 changes: 29 additions & 21 deletions train/era5.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,8 +12,10 @@

from einops import rearrange

from pathlib import Path

class LitGraphForecaster(pl.LightningModule):

class LitFengWuGHR(pl.LightningModule):
"""
LightningModule for graph-based weather forecasting.
Expand All @@ -23,7 +25,7 @@ class LitGraphForecaster(pl.LightningModule):
lr : Learning rate for optimizer.
Methods:
__init__: Initialize the LitGraphForecaster object.
__init__: Initialize the LitFengWuGHR object.
forward: Forward pass of the model.
training_step: Training step.
configure_optimizers: Configure the optimizer for training.
Expand All @@ -44,7 +46,7 @@ def __init__(

):
"""
Initialize the LitGraphForecaster object with the required args.
Initialize the LitFengWuGHR object with the required args.
Args:
lat_lons : List of latitude and longitude values.
Expand Down Expand Up @@ -135,16 +137,27 @@ def __getitem__(self, index):

if __name__ == "__main__":

ckpt_path = Path("./checkpoints")
patch_size = 4
grid_step = 20
variables = ["2m_temperature",
"surface_pressure",
"10m_u_component_of_wind",
"10m_v_component_of_wind"]

channels = len(variables)
ckpt_path.mkdir(parents=True, exist_ok=True)

reanalysis = xarray.open_zarr(
'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3',
storage_options=dict(token='anon'),

)
reanalysis = reanalysis.isel(time=slice(100, 400), longitude=slice(

reanalysis = reanalysis.sel(time=slice('2020-01-01', '2021-01-01'))
reanalysis = reanalysis.isel(time=slice(100,107), longitude=slice(
0, 1440, grid_step), latitude=slice(0, 721, grid_step))

reanalysis = reanalysis[variables]
print(f'size: {reanalysis.nbytes / (1024 ** 3)} GiB')

lat_lons = np.array(
Expand All @@ -155,31 +168,26 @@ def __getitem__(self, index):
).T.reshape((-1, 2))

checkpoint_callback = ModelCheckpoint(
dirpath="./checkpoints", save_top_k=1, monitor="loss")
reanalysis = reanalysis[["2m_temperature",
"surface_pressure",
"10m_u_component_of_wind",
"10m_v_component_of_wind"]]

shape = np.asarray(reanalysis.to_array()).shape
channels = shape[0]
dirpath=ckpt_path, save_top_k=1, monitor="loss")

dset = DataLoader(Era5Dataset(reanalysis), batch_size=10, num_workers=8)
model = LitGraphForecaster(lat_lons=lat_lons,
channels=channels,
image_size=(721//grid_step, 1440//grid_step),
patch_size=patch_size,
depth=5,
heads=4,
mlp_dim=5)
model = LitFengWuGHR(lat_lons=lat_lons,
channels=channels,
image_size=(721//grid_step, 1440//grid_step),
patch_size=patch_size,
depth=5,
heads=4,
mlp_dim=5)
trainer = pl.Trainer(
accelerator="gpu",
devices=-1,
max_epochs=1000,
max_epochs=100,
precision="16-mixed",
callbacks=[checkpoint_callback],
log_every_n_steps=3

)

trainer.fit(model, dset)

torch.save(model.state_dict(), ckpt_path / "best.pt")

0 comments on commit b9c1e30

Please sign in to comment.