Draft
Changes from all commits
22 commits
e30a8be  add embed_gradient_checkpoint_mode to config (Oct 24, 2025)
427f6dd  add embed_gradient_checkpoint_mode condition to forward_channels (Oct 24, 2025)
cf509e3  add embed_gradient_checkpoint_mode to args (Oct 24, 2025)
125d50d  test pipeline for embed_gradient_checkpoint_mode true (Oct 24, 2025)
53c94a5  test pipeline for embed_gradient_checkpoint_mode false (Oct 24, 2025)
797cf1b  add ae_local_blocks_grdient_checkpoint_mode to config (Oct 25, 2025)
8575019  add ae_local_blocks_grdient_checkpoint_mode cond (Oct 25, 2025)
c144f28  add ae_adapter_grdient_checkpoint_mode (Oct 25, 2025)
11161e5  add assimilate_global_gradient_checkpoint_mode to config (Oct 25, 2025)
4b32a13  add pred_gradient_checkpoint_mode to config (Oct 25, 2025)
a8f5633  Merge branch 'develop' of https://github.com/javak87/WeatherGenerator… (Oct 27, 2025)
f2a78d3  ruff embeddings.py (Oct 27, 2025)
bca3381  remove ae_local_blocks_grdient_checkpoint_mode for __init__ (Oct 27, 2025)
909cff3  ruff the code (Oct 27, 2025)
e4b7746  ruff the code (Oct 27, 2025)
cf5b13f  ruff the code (Oct 27, 2025)
e0532c2  ruff the code (Oct 27, 2025)
1c012ce  Merge branch 'javad/dev/cond_checkpoint_ae_local-1141' into javad/dev… (Oct 27, 2025)
9827650  Merge branch 'javad/dev/cond_checkpoint_assimilate_global-1141' into … (Oct 27, 2025)
e1630a1  Merge branch 'javad/dev/cond_checkpoint_embed_transformer-1141' into … (Oct 27, 2025)
92cd98e  Merge branch 'javad/dev/cond_checkpoint_predict-1141' into javad/dev/… (Oct 27, 2025)
a1b7f1f  Merge branch 'develop' into javad/dev/cond_checkpoint_all-1141 (javak87, Oct 28, 2025)
5 changes: 5 additions & 0 deletions config/default_config.yml
@@ -6,6 +6,7 @@ embed_centroids_local_coords: False
 embed_size_centroids: 0
 embed_unembed_mode: "block"
 embed_dropout_rate: 0.1
+embed_gradient_checkpoint_mode: False

 target_cell_local_prediction: True

@@ -17,11 +18,13 @@ ae_local_with_qk_lnorm: True

 ae_local_num_queries: 1
 ae_local_queries_per_cell: False
+ae_local_blocks_grdient_checkpoint_mode: False
 ae_adapter_num_heads: 16
 ae_adapter_embed: 128
 ae_adapter_with_qk_lnorm: True
 ae_adapter_with_residual: True
 ae_adapter_dropout_rate: 0.1
+ae_adapter_grdient_checkpoint_mode: False

 ae_global_dim_embed: 2048
 ae_global_num_blocks: 8
@@ -33,12 +36,14 @@ ae_global_with_qk_lnorm: True
 ae_global_att_dense_rate: 1.0
 ae_global_block_factor: 64
 ae_global_mlp_hidden_factor: 2
+assimilate_global_gradient_checkpoint_mode: False

 decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning
 pred_adapter_kv: False
 pred_self_attention: True
 pred_dyadic_dims: False
 pred_mlp_adaln: True
+pred_gradient_checkpoint_mode: False

 # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
 # one is training an auto-encoder
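For reference, a minimal sketch of how the new options might be combined in a run configuration. The key names are taken from this diff (including the "grdient" spelling of two of them); the True/False choices below are purely illustrative, not recommendations from this PR:

```yaml
# Illustrative only: enable activation checkpointing for the heavier components
# to reduce peak memory at the cost of extra recomputation in the backward pass.
embed_gradient_checkpoint_mode: True
ae_local_blocks_grdient_checkpoint_mode: True
ae_adapter_grdient_checkpoint_mode: False
assimilate_global_gradient_checkpoint_mode: True
pred_gradient_checkpoint_mode: False
```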
54 changes: 40 additions & 14 deletions src/weathergen/model/embeddings.py
@@ -34,6 +34,7 @@ def __init__(
         norm_type="LayerNorm",
         embed_size_centroids=64,
         unembed_mode="full",
+        embed_gradient_checkpoint_mode=True,
         stream_name="stream_embed",
     ):
         """Constructor
@@ -59,6 +60,7 @@ def __init__(
         self.num_heads = num_heads
         self.embed_size_centroids = embed_size_centroids
         self.unembed_mode = unembed_mode
+        self.embed_gradient_checkpoint_mode = embed_gradient_checkpoint_mode

         norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm

@@ -148,23 +150,47 @@
     def forward_channels(self, x_in, centroids):
         peh = positional_encoding_harmonic

-        # embed provided input data
-        x = peh(checkpoint(self.embed, x_in.transpose(-2, -1), use_reentrant=False))
-
-        for layer in self.layers:
-            x = checkpoint(layer, x, use_reentrant=False)
-
-        # read out
-        if self.unembed_mode == "full":
-            out = checkpoint(self.unembed, self.ln_final(x.flatten(-2, -1)), use_reentrant=False)
-        elif self.unembed_mode == "block":
-            out = [
-                checkpoint(ue, ln(x[:, i]), use_reentrant=False)
-                for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True))
-            ]
-            out = torch.stack(out, dim=1).flatten(-2, -1)
-        else:
-            assert False
+        if self.embed_gradient_checkpoint_mode:
+            # embed provided input data
+            x = peh(checkpoint(self.embed, x_in.transpose(-2, -1), use_reentrant=False))
+
+            for layer in self.layers:
+                x = checkpoint(layer, x, use_reentrant=False)
+
+            # read out
+            if self.unembed_mode == "full":
+                out = checkpoint(
+                    self.unembed,
+                    self.ln_final(x.flatten(-2, -1)),
+                    use_reentrant=False,
+                )
+            elif self.unembed_mode == "block":
+                out = [
+                    checkpoint(ue, ln(x[:, i]), use_reentrant=False)
+                    for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True))
+                ]
+                out = torch.stack(out, dim=1).flatten(-2, -1)
+            else:
+                assert False
+        else:
+            # embed provided input data
+            x = peh(self.embed(x_in.transpose(-2, -1)))
+
+            for layer in self.layers:
+                x = layer(x)
+
+            # read out
+            if self.unembed_mode == "full":
+                out = self.unembed(self.ln_final(x.flatten(-2, -1)))
+            elif self.unembed_mode == "block":
+                out = [
+                    ue(ln(x[:, i]))
+                    for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True))
+                ]
+                out = torch.stack(out, dim=1).flatten(-2, -1)
+            else:
+                assert False

         # append centroids
         if self.embed_size_centroids > 0:
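The same if/else pattern recurs for every component that gains a checkpoint flag in this PR. A minimal sketch of one way the duplicated branches could be expressed, using a hypothetical helper that is not part of this PR and only assumes the standard torch.utils.checkpoint API:

```python
from torch.utils.checkpoint import checkpoint


def maybe_checkpoint(module, *args, enabled: bool = True):
    """Run `module` under activation checkpointing when `enabled`, else call it directly.

    Hypothetical helper for illustration; not part of this PR.
    """
    if enabled:
        return checkpoint(module, *args, use_reentrant=False)
    return module(*args)


# Usage sketch inside forward_channels (illustrative only):
# x = peh(maybe_checkpoint(self.embed, x_in.transpose(-2, -1),
#                          enabled=self.embed_gradient_checkpoint_mode))
# for layer in self.layers:
#     x = maybe_checkpoint(layer, x, enabled=self.embed_gradient_checkpoint_mode)
```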
1 change: 1 addition & 0 deletions src/weathergen/model/engines.py
@@ -71,6 +71,7 @@ def create(self) -> torch.nn.ModuleList:
                     norm_type=self.cf.norm_type,
                     embed_size_centroids=self.cf.embed_size_centroids,
                     unembed_mode=self.cf.embed_unembed_mode,
+                    embed_gradient_checkpoint_mode=self.cf.embed_gradient_checkpoint_mode,
                     stream_name=stream_name,
                 )
             )
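The flag is read directly from the run config here, so configs created before this key existed would typically fail at this line. A defensive read, shown only as a sketch and not something this PR does, could fall back to the default used in default_config.yml:

```python
# Sketch only (not part of this PR): tolerate configs that predate the new key,
# defaulting to False as in config/default_config.yml.
def embed_checkpoint_enabled(cf) -> bool:
    return getattr(cf, "embed_gradient_checkpoint_mode", False)
```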
84 changes: 58 additions & 26 deletions src/weathergen/model/model.py
@@ -743,8 +743,12 @@ def assimilate_local(
                 tokens_global_all += [tokens_global_c]
                 continue

-            for block in self.ae_local_blocks:
-                tokens_c = checkpoint(block, tokens_c, cell_lens_c, use_reentrant=False)
+            if self.cf.ae_local_blocks_grdient_checkpoint_mode:
+                for block in self.ae_local_blocks:
+                    tokens_c = checkpoint(block, tokens_c, cell_lens_c, use_reentrant=False)
+            else:
+                for block in self.ae_local_blocks:
+                    tokens_c = block(tokens_c, cell_lens_c)

             if self.cf.latent_noise_kl_weight > 0.0:
                 tokens_c, posteriors_c = self.interpolate_latents.interpolate_with_noise(
@@ -754,15 +758,24 @@
             else:
                 tokens_c, posteriors = tokens_c, 0.0

-            for block in self.ae_adapter:
-                tokens_global_c = checkpoint(
-                    block,
-                    tokens_global_c,
-                    tokens_c,
-                    q_cells_lens_c,
-                    cell_lens_c,
-                    use_reentrant=False,
-                )
+            if self.cf.ae_adapter_grdient_checkpoint_mode:
+                for block in self.ae_adapter:
+                    tokens_global_c = checkpoint(
+                        block,
+                        tokens_global_c,
+                        tokens_c,
+                        q_cells_lens_c,
+                        cell_lens_c,
+                        use_reentrant=False,
+                    )
+            else:
+                for block in self.ae_adapter:
+                    tokens_global_c = block(
+                        tokens_global_c,
+                        tokens_c,
+                        q_cells_lens_c,
+                        cell_lens_c,
+                    )

             tokens_global_all += [tokens_global_c]

@@ -787,8 +800,12 @@ def assimilate_global(self, model_params: ModelParams, tokens: torch.Tensor) ->
         """

         # global assimilation engine and adapter
-        for block in self.ae_global_blocks:
-            tokens = checkpoint(block, tokens, use_reentrant=False)
+        if self.cf.assimilate_global_gradient_checkpoint_mode:
+            for block in self.ae_global_blocks:
+                tokens = checkpoint(block, tokens, use_reentrant=False)
+        else:
+            for block in self.ae_global_blocks:
+                tokens = block(tokens)

         return tokens

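All checkpointed paths in this PR call torch.utils.checkpoint.checkpoint with use_reentrant=False, the non-reentrant implementation that current PyTorch documentation recommends. A minimal standalone illustration of that call, unrelated to this repository's modules:

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)

# The intermediate activations inside `layer` are not stored; they are recomputed
# during backward, which lowers peak memory at the cost of extra compute.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
```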
@@ -855,18 +872,30 @@ def predict(
             ## embed token coords, concatenating along batch dimension
             # (which is taking care of through the varlen attention)
             # arguably we should to the mixed precision policy when creating the model in FSDP
-            tc_tokens = torch.cat(
-                [
-                    checkpoint(
-                        tc_embed,
-                        streams_data[i_b][ii].target_coords[fstep],
-                        use_reentrant=False,
-                    )
-                    if len(streams_data[i_b][ii].target_coords[fstep].shape) > 1
-                    else streams_data[i_b][ii].target_coords[fstep]
-                    for i_b in range(len(streams_data))
-                ]
-            )
+            if self.cf.pred_gradient_checkpoint_mode:
+                tc_tokens = torch.cat(
+                    [
+                        checkpoint(
+                            tc_embed,
+                            streams_data[i_b][ii].target_coords[fstep],
+                            use_reentrant=False,
+                        )
+                        if len(streams_data[i_b][ii].target_coords[fstep].shape) > 1
+                        else streams_data[i_b][ii].target_coords[fstep]
+                        for i_b in range(len(streams_data))
+                    ]
+                )
+            else:
+                tc_tokens = torch.cat(
+                    [
+                        tc_embed(
+                            streams_data[i_b][ii].target_coords[fstep],
+                        )
+                        if len(streams_data[i_b][ii].target_coords[fstep].shape) > 1
+                        else streams_data[i_b][ii].target_coords[fstep]
+                        for i_b in range(len(streams_data))
+                    ]
+                )

             # skip when coordinate embeddings yields nan (i.e. the coord embedding network diverged)
             if torch.isnan(tc_tokens).any():
@@ -906,6 +935,9 @@ def predict(
                 )

             # final prediction head to map back to physical space
-            preds_tokens += [checkpoint(self.pred_heads[ii], tc_tokens, use_reentrant=False)]
+            if self.cf.pred_gradient_checkpoint_mode:
+                preds_tokens += [checkpoint(self.pred_heads[ii], tc_tokens, use_reentrant=False)]
+            else:
+                preds_tokens += [self.pred_heads[ii](tc_tokens)]

         return preds_tokens
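Since each new flag only switches between a checkpointed call and a plain call, its practical effect is a memory/compute trade-off that can be checked per flag. A rough sketch of such a comparison, with placeholder model and batch objects that are not from this repository:

```python
import torch


def peak_memory_mib(model: torch.nn.Module, batch: torch.Tensor) -> float:
    """One forward/backward pass, returning peak CUDA memory in MiB (requires a GPU)."""
    torch.cuda.reset_peak_memory_stats()
    loss = model(batch).sum()  # placeholder loss; the real training loss differs
    loss.backward()
    return torch.cuda.max_memory_allocated() / 2**20


# Compare, e.g., assimilate_global_gradient_checkpoint_mode: True vs. False in the
# config: the checkpointed run should show a lower peak but a longer backward pass.
```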