config/default_config.yml (15 additions & 14 deletions)
@@ -9,8 +9,8 @@ embed_dropout_rate: 0.1

target_cell_local_prediction: True

-ae_local_dim_embed: 1024
-ae_local_num_blocks: 2
+ae_local_dim_embed: 256
+ae_local_num_blocks: 0
ae_local_num_heads: 16
ae_local_dropout_rate: 0.1
ae_local_with_qk_lnorm: True
@@ -24,8 +24,8 @@ ae_adapter_with_residual: True
ae_adapter_dropout_rate: 0.1

ae_global_dim_embed: 2048
-ae_global_num_blocks: 8
-ae_global_num_heads: 32
+ae_global_num_blocks: 4
+ae_global_num_heads: 16
ae_global_dropout_rate: 0.1
ae_global_with_qk_lnorm: True
# TODO: switching to < 1 triggers triton-related issues.
Expand All @@ -42,18 +42,19 @@ pred_mlp_adaln: True

# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
# one is training an auto-encoder
forecast_offset : 0
forecast_offset : 1
forecast_delta_hrs: 0
forecast_steps: 0
forecast_policy: null
forecast_steps: 4
forecast_policy: "fixed"
forecast_freeze_model: False
forecast_att_dense_rate: 1.0
fe_num_blocks: 0
fe_num_blocks: 8
fe_num_heads: 16
fe_dropout_rate: 0.1
fe_with_qk_lnorm: True
impute_latent_noise_std: 0.0 # 1e-4

healpix_level: 5
healpix_level: 4

with_mixed_precision: True
with_flash_attention: True
@@ -88,7 +89,7 @@ freeze_modules: ""

# training mode: "forecast" or "masking" (masked token modeling)
# for "masking" to train with auto-encoder mode, forecast_offset should be 0
-training_mode: "masking"
+training_mode: "forecast"
# masking rate when training mode is "masking"; ignored in forecast mode
masking_rate: 0.6
# sample the masking rate (with normal distribution centered at masking_rate)
@@ -108,17 +109,17 @@ masking_strategy_config: {"strategies": ["random", "healpix", "channel"],
                          "same_strategy_per_batch": false
                         }

-num_epochs: 32
+num_epochs: 64
samples_per_epoch: 4096
samples_per_validation: 512
shuffle: True

lr_scaling_policy: "sqrt"
lr_start: 1e-6
-lr_max: 5e-5
-lr_final_decay: 1e-6
+lr_max: 0.0001
+lr_final_decay: 2e-6
lr_final: 0.0
-lr_steps_warmup: 512
+lr_steps_warmup: 256
lr_steps_cooldown: 512
lr_policy_warmup: "cosine"
lr_policy_decay: "linear"
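The lr_* settings above describe a three-phase schedule: a cosine warmup from lr_start to lr_max over lr_steps_warmup steps, a linear decay down to lr_final_decay, and a final cooldown to lr_final over lr_steps_cooldown steps. A minimal sketch of that shape, assuming a generic scheduler rather than WeatherGen's actual implementation (which also applies the "sqrt" scaling policy):

import math

# Hypothetical sketch of the schedule implied by the lr_* settings above;
# the real scheduler lives in the training code and may differ in detail.
def lr_at_step(step: int, total_steps: int) -> float:
    lr_start, lr_max, lr_final_decay, lr_final = 1e-6, 1e-4, 2e-6, 0.0
    steps_warmup, steps_cooldown = 256, 512
    if step < steps_warmup:
        # cosine warmup: lr_start -> lr_max
        t = step / steps_warmup
        return lr_start + (lr_max - lr_start) * 0.5 * (1.0 - math.cos(math.pi * t))
    if step >= total_steps - steps_cooldown:
        # cooldown: lr_final_decay -> lr_final
        t = (step - (total_steps - steps_cooldown)) / steps_cooldown
        return lr_final_decay + (lr_final - lr_final_decay) * t
    # linear decay: lr_max -> lr_final_decay
    t = (step - steps_warmup) / max(1, total_steps - steps_warmup - steps_cooldown)
    return lr_max + (lr_final_decay - lr_max) * t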
config/eval_config.yml (28 additions & 0 deletions)
@@ -0,0 +1,28 @@
verbose: true
image_format : "png" # options: "png", "pdf", "svg", "eps", "jpg", ...
dpi_val : 300
summary_plots : true
print_summary: false

evaluation:
  metrics : ["rmse"]
  regions: ["global"]

run_ids :

  ptluswdo:
    label: "ptluswdo: 64ep 2fs (naoj54ch) + 32ep 8fs 2e-5"
    epoch: 0
    rank: 0
    streams:
      ERA5:
        channels: ["2t", "10u", "10v", "z_500", "t_850", "u_850", "v_850", "q_850", ]
        #channels: ["2t", "q_850", ]
    evaluation:
      sample: "all"
      forecast_step: "all"
    plotting:
      sample: [0]
      forecast_step: [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]
      plot_maps: true
      plot_histograms: false
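The new eval_config.yml drives per-run evaluation and plotting. A hypothetical loader showing how the nested run_ids/streams/channels structure might be traversed; the repo's actual evaluation entry point is not shown in this diff, and the nesting above is reconstructed:

import yaml  # PyYAML

with open("config/eval_config.yml") as f:
    cfg = yaml.safe_load(f)

# Enumerate every (run, stream, channel) combination selected for evaluation.
for run_id, run in cfg["run_ids"].items():
    for stream_name, stream in run["streams"].items():
        for channel in stream["channels"]:
            print(f"{run_id} [{run['label']}]: {stream_name}/{channel}")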
config/runs_plot_train.yml (6 additions & 0 deletions)
@@ -0,0 +1,6 @@
train :
  plot :
    lnjzhore :
      slurm_id: 0
      description: "Christian's naoj54ch with new code"
      eval: vgbndhco
config/streams/era5_1deg/era5.yml (1 addition & 0 deletions)
@@ -10,6 +10,7 @@
ERA5 :
  type : anemoi
  filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-6h-v8.zarr']
+  # filenames : ['aifs-ea-an-oper-0001-mars-o48-1979-2024-6h-v1.zarr']
  source_exclude : ['w_', 'skt', 'tcw', 'cp', 'tp']
  target_exclude : ['w_', 'slor', 'sdor', 'tcw', 'cp', 'tp']
  loss_weight : 1.
src/weathergen/model/engines.py (26 additions & 13 deletions)
@@ -10,8 +10,8 @@
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from weathergen.common.config import Config

from weathergen.model.attention import (
MultiCrossAttentionHeadVarlen,
MultiCrossAttentionHeadVarlenSlicedQ,
@@ -24,7 +24,7 @@
    StreamEmbedLinear,
    StreamEmbedTransformer,
)
-from weathergen.model.layers import MLP
+from weathergen.model.layers import FEMLP, MLP
from weathergen.model.utils import ActivationFactory
from weathergen.utils.utils import get_dtype

@@ -317,18 +317,31 @@ def create(self) -> torch.nn.ModuleList:
                    attention_dtype=get_dtype(self.cf.attention_dtype),
                )
            )
-            # Add MLP block
-            self.fe_blocks.append(
-                MLP(
-                    self.cf.ae_global_dim_embed,
-                    self.cf.ae_global_dim_embed,
-                    with_residual=True,
-                    dropout_rate=self.cf.fe_dropout_rate,
-                    norm_type=self.cf.norm_type,
-                    dim_aux=1,
-                    norm_eps=self.cf.mlp_norm_eps,
-                )
-            )
+
+            if i + 1 == self.cf.ae_global_num_blocks:
+                self.fe_blocks.append(
+                    FEMLP(
+                        self.cf.ae_global_dim_embed,
+                        self.cf.ae_global_dim_embed,
+                        with_residual=True,
+                        dropout_rate=self.cf.fe_dropout_rate,
+                        norm_type=self.cf.norm_type,
+                        dim_aux=1,
+                        norm_eps=self.cf.mlp_norm_eps,
+                    )
+                )
+            else:
+                self.fe_blocks.append(
+                    MLP(
+                        self.cf.ae_global_dim_embed,
+                        self.cf.ae_global_dim_embed,
+                        with_residual=True,
+                        dropout_rate=self.cf.fe_dropout_rate,
+                        norm_type=self.cf.norm_type,
+                        dim_aux=1,
+                        norm_eps=self.cf.mlp_norm_eps,
+                    )
+                )

        def init_weights_final(m):
            if isinstance(m, torch.nn.Linear):
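The net effect of this change: each global block still contributes an attention module followed by an MLP, but the MLP slot of the last block is now an FEMLP, whose post-residual LayerNorm re-normalizes the latents once before they leave the forecast engine. A toy sketch of the resulting layout (class names as strings, not the actual weathergen modules):

def build_fe_block_layout(num_blocks: int) -> list[str]:
    # Mirrors the loop in create(): attention block, then MLP,
    # with FEMLP substituted only on the final iteration.
    layout = []
    for i in range(num_blocks):
        layout.append("Attention")
        layout.append("FEMLP" if i + 1 == num_blocks else "MLP")
    return layout

print(build_fe_block_layout(4))
# ['Attention', 'MLP', 'Attention', 'MLP', 'Attention', 'MLP', 'Attention', 'FEMLP']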
src/weathergen/model/layers.py (77 additions & 0 deletions)
@@ -93,3 +93,80 @@ def forward(self, *args):
        x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]])

        return x


class FEMLP(torch.nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        num_layers=2,
        hidden_factor=2,
        pre_layer_norm=True,
        dropout_rate=0.0,
        nonlin=torch.nn.GELU,
        with_residual=False,
        norm_type="LayerNorm",
        dim_aux=None,
        norm_eps=1e-5,
        name: str | None = None,
    ):
        """Constructor"""

        super(FEMLP, self).__init__()

        if name is not None:
            self.name = name

        assert num_layers >= 2

        self.with_residual = with_residual
        self.with_aux = dim_aux is not None
        dim_hidden = int(dim_in * hidden_factor)

        self.layers = torch.nn.ModuleList()

        norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm

        if pre_layer_norm:
            self.layers.append(
                norm(dim_in, eps=norm_eps)
                if dim_aux is None
                else AdaLayerNorm(dim_in, dim_aux, norm_eps=norm_eps)
            )

        self.layers.append(torch.nn.Linear(dim_in, dim_hidden))
        self.layers.append(nonlin())
        self.layers.append(torch.nn.Dropout(p=dropout_rate))

        for _ in range(num_layers - 2):
            self.layers.append(torch.nn.Linear(dim_hidden, dim_hidden))
            self.layers.append(nonlin())
            self.layers.append(torch.nn.Dropout(p=dropout_rate))

        self.layers.append(torch.nn.Linear(dim_hidden, dim_out))

        # Add LayerNorm after skip connection if residuals are used
        if self.with_residual:
            # self.residual_norm = AdaLayerNorm(
            #     dim_out, dim_aux, norm_eps=norm_eps
            # ) # norm(dim_out, eps=norm_eps)
            self.residual_norm = torch.nn.LayerNorm(dim_out, eps=norm_eps, elementwise_affine=False)

    def forward(self, *args):
        x, x_in, aux = args[0], args[0], args[-1]

        for i, layer in enumerate(self.layers):
            x = layer(x, aux) if (i == 0 and self.with_aux) else layer(x)

        if self.with_residual:
            if x.shape[-1] == x_in.shape[-1]:
                x = x_in + x
            else:
                assert x.shape[-1] % x_in.shape[-1] == 0
                x = x + x_in.repeat([*[1 for _ in x.shape[:-1]], x.shape[-1] // x_in.shape[-1]])

            # Apply LayerNorm to the residual connection
            x = self.residual_norm(x)

        return x
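A minimal smoke test of FEMLP as defined above; the tensor shapes and the aux layout are assumptions, since they depend on AdaLayerNorm, which is defined elsewhere in layers.py:

import torch

from weathergen.model.layers import FEMLP

mlp = FEMLP(2048, 2048, with_residual=True, dim_aux=1)
x = torch.randn(8, 16, 2048)  # (batch, tokens, embed dim); illustrative shape
aux = torch.randn(8, 1)       # conditioning consumed by the AdaLayerNorm pre-norm
y = mlp(x, aux)               # forward reads x = args[0], aux = args[-1]
print(y.shape)                # torch.Size([8, 16, 2048]), post-residual LayerNorm applied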