From e30a8be54b3f01a7ef2af3f477b2c821cfe35a01 Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Fri, 24 Oct 2025 14:11:35 +0200 Subject: [PATCH 01/16] add embed_gradient_checkpoint_mode to config --- config/default_config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/default_config.yml b/config/default_config.yml index 679f58dd3..ae4a24e35 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -6,6 +6,7 @@ embed_centroids_local_coords: False embed_size_centroids: 0 embed_unembed_mode: "block" embed_dropout_rate: 0.1 +embed_gradient_checkpoint_mode: False target_cell_local_prediction: True From 427f6ddeac5581bbf263c08bec630ffbfc08256b Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Fri, 24 Oct 2025 14:13:14 +0200 Subject: [PATCH 02/16] add embed_gradient_checkpoint_mode condition to forward_channels --- src/weathergen/model/embeddings.py | 50 +++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/src/weathergen/model/embeddings.py b/src/weathergen/model/embeddings.py index da916413f..e18657312 100644 --- a/src/weathergen/model/embeddings.py +++ b/src/weathergen/model/embeddings.py @@ -34,6 +34,7 @@ def __init__( norm_type="LayerNorm", embed_size_centroids=64, unembed_mode="full", + embed_gradient_checkpoint_mode=True, stream_name="stream_embed", ): """Constructor @@ -59,6 +60,7 @@ def __init__( self.num_heads = num_heads self.embed_size_centroids = embed_size_centroids self.unembed_mode = unembed_mode + self.embed_gradient_checkpoint_mode = embed_gradient_checkpoint_mode norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm @@ -148,23 +150,43 @@ def __init__( def forward_channels(self, x_in, centroids): peh = positional_encoding_harmonic - # embed provided input data - x = peh(checkpoint(self.embed, x_in.transpose(-2, -1), use_reentrant=False)) + if self.embed_gradient_checkpoint_mode: + # embed provided input data + x = peh(checkpoint(self.embed, x_in.transpose(-2, -1), use_reentrant=False)) - for layer in self.layers: - x = checkpoint(layer, x, use_reentrant=False) + for layer in self.layers: + x = checkpoint(layer, x, use_reentrant=False) + + # read out + if self.unembed_mode == "full": + out = checkpoint(self.unembed, self.ln_final(x.flatten(-2, -1)), use_reentrant=False) + elif self.unembed_mode == "block": + out = [ + checkpoint(ue, ln(x[:, i]), use_reentrant=False) + for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True)) + ] + out = torch.stack(out, dim=1).flatten(-2, -1) + else: + assert False - # read out - if self.unembed_mode == "full": - out = checkpoint(self.unembed, self.ln_final(x.flatten(-2, -1)), use_reentrant=False) - elif self.unembed_mode == "block": - out = [ - checkpoint(ue, ln(x[:, i]), use_reentrant=False) - for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True)) - ] - out = torch.stack(out, dim=1).flatten(-2, -1) else: - assert False + # embed provided input data + x = peh(self.embed(x_in.transpose(-2, -1))) + + for layer in self.layers: + x = layer(x) + + # read out + if self.unembed_mode == "full": + out = self.unembed(self.ln_final(x.flatten(-2, -1))) + elif self.unembed_mode == "block": + out = [ + ue(ln(x[:, i])) + for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True)) + ] + out = torch.stack(out, dim=1).flatten(-2, -1) + else: + assert False # append centroids if self.embed_size_centroids > 0: From cf509e3839ce4c0e08c9c5f05551414ef7302259 Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Fri, 24 Oct 2025 14:13:42 +0200 Subject: [PATCH 03/16] add embed_gradient_checkpoint_mode to args --- src/weathergen/model/engines.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 78d11a4a6..b23ed817d 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -71,6 +71,7 @@ def create(self) -> torch.nn.ModuleList: norm_type=self.cf.norm_type, embed_size_centroids=self.cf.embed_size_centroids, unembed_mode=self.cf.embed_unembed_mode, + embed_gradient_checkpoint_mode=self.cf.embed_gradient_checkpoint_mode, stream_name=stream_name, ) ) From 125d50d66a774f19e357d71b388a3045960daea5 Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Fri, 24 Oct 2025 14:25:55 +0200 Subject: [PATCH 04/16] test pipeline for embed_gradient_checkpoint_mode true --- config/default_config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index ae4a24e35..e6c8dec6f 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -6,7 +6,7 @@ embed_centroids_local_coords: False embed_size_centroids: 0 embed_unembed_mode: "block" embed_dropout_rate: 0.1 -embed_gradient_checkpoint_mode: False +embed_gradient_checkpoint_mode: True target_cell_local_prediction: True From 53c94a54055c357818d93be302ed88e62671e88d Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Fri, 24 Oct 2025 14:32:15 +0200 Subject: [PATCH 05/16] test pipeline for embed_gradient_checkpoint_mode false --- config/default_config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index e6c8dec6f..ae4a24e35 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -6,7 +6,7 @@ embed_centroids_local_coords: False embed_size_centroids: 0 embed_unembed_mode: "block" embed_dropout_rate: 0.1 -embed_gradient_checkpoint_mode: True +embed_gradient_checkpoint_mode: False target_cell_local_prediction: True From 797cf1b880a3416acbc93bba36a9b05b15ea86f8 Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Sat, 25 Oct 2025 13:05:35 +0200 Subject: [PATCH 06/16] add ae_local_blocks_grdient_checkpoint_mode to config --- config/default_config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/default_config.yml b/config/default_config.yml index 620f5c4ae..c90ccb1d7 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -17,6 +17,7 @@ ae_local_with_qk_lnorm: True ae_local_num_queries: 1 ae_local_queries_per_cell: False +ae_local_blocks_grdient_checkpoint_mode: False ae_adapter_num_heads: 16 ae_adapter_embed: 128 ae_adapter_with_qk_lnorm: True From 8575019300943b230ddca261ccb5248032518884 Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Sat, 25 Oct 2025 13:06:27 +0200 Subject: [PATCH 07/16] add ae_local_blocks_grdient_checkpoint_mode cond --- src/weathergen/model/model.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index f049caeb4..ec21f0ee0 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -253,6 +253,7 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.sources_size = sources_size self.targets_num_channels = targets_num_channels self.targets_coords_size = targets_coords_size + self.ae_local_blocks_grdient_checkpoint_mode = self.cf.ae_local_blocks_grdient_checkpoint_mode ######################################### def create(self) -> "Model": @@ -743,8 +744,12 @@ def assimilate_local( tokens_global_all += [tokens_global_c] continue - for block in self.ae_local_blocks: - tokens_c = checkpoint(block, tokens_c, cell_lens_c, use_reentrant=False) + if self.ae_local_blocks_grdient_checkpoint_mode: + for block in self.ae_local_blocks: + tokens_c = checkpoint(block, tokens_c, cell_lens_c, use_reentrant=False) + else: + for block in self.ae_local_blocks: + tokens_c = block(tokens_c, cell_lens_c) if self.cf.latent_noise_kl_weight > 0.0: tokens_c, posteriors_c = self.interpolate_latents.interpolate_with_noise( From c144f2864aafcdd8ce995aae1d9e5b9068893f8d Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Sun, 26 Oct 2025 00:09:09 +0200 Subject: [PATCH 08/16] add ae_adapter_grdient_checkpoint_mode --- config/default_config.yml | 1 + src/weathergen/model/model.py | 28 +++++++++++++++++++--------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 620f5c4ae..09be588aa 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -22,6 +22,7 @@ ae_adapter_embed: 128 ae_adapter_with_qk_lnorm: True ae_adapter_with_residual: True ae_adapter_dropout_rate: 0.1 +ae_adapter_grdient_checkpoint_mode: False ae_global_dim_embed: 2048 ae_global_num_blocks: 8 diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index f049caeb4..68b77ff63 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -754,15 +754,25 @@ def assimilate_local( else: tokens_c, posteriors = tokens_c, 0.0 - for block in self.ae_adapter: - tokens_global_c = checkpoint( - block, - tokens_global_c, - tokens_c, - q_cells_lens_c, - cell_lens_c, - use_reentrant=False, - ) + if self.cf.ae_adapter_grdient_checkpoint_mode: + for block in self.ae_adapter: + tokens_global_c = checkpoint( + block, + tokens_global_c, + tokens_c, + q_cells_lens_c, + cell_lens_c, + use_reentrant=False, + ) + else: + for block in self.ae_adapter: + tokens_global_c = block( + tokens_global_c, + tokens_c, + q_cells_lens_c, + cell_lens_c, + ) + tokens_global_all += [tokens_global_c] From 11161e51b9dbd7f951ed49f23c954197513dc814 Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Sun, 26 Oct 2025 00:33:52 +0200 Subject: [PATCH 09/16] add assimilate_global_gradient_checkpoint_mode to config --- config/default_config.yml | 1 + src/weathergen/model/model.py | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 620f5c4ae..c50ae3652 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -33,6 +33,7 @@ ae_global_with_qk_lnorm: True ae_global_att_dense_rate: 1.0 ae_global_block_factor: 64 ae_global_mlp_hidden_factor: 2 +assimilate_global_gradient_checkpoint_mode: False decoder_type: PerceiverIOCoordConditioning # CrossAttentionAdaNormConditioning pred_adapter_kv: False diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index f049caeb4..52456bd0a 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -787,9 +787,13 @@ def assimilate_global(self, model_params: ModelParams, tokens: torch.Tensor) -> """ # global assimilation engine and adapter - for block in self.ae_global_blocks: - tokens = checkpoint(block, tokens, use_reentrant=False) - + if self.cf.assimilate_global_gradient_checkpoint_mode: + for block in self.ae_global_blocks: + tokens = checkpoint(block, tokens, use_reentrant=False) + else: + for block in self.ae_global_blocks: + tokens = block(tokens) + return tokens ######################################### From 4b32a1371802211f20f6b84cdf415e130f90ac36 Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Sun, 26 Oct 2025 00:54:31 +0200 Subject: [PATCH 10/16] add pred_gradient_checkpoint_mode to config --- config/default_config.yml | 1 + src/weathergen/model/model.py | 42 ++++++++++++++++++++++++----------- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/config/default_config.yml b/config/default_config.yml index 620f5c4ae..e2b3b9210 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -39,6 +39,7 @@ pred_adapter_kv: False pred_self_attention: True pred_dyadic_dims: False pred_mlp_adaln: True +pred_gradient_checkpoint_mode: False # number of steps offset applied to first target window; if set to zero and forecast_steps=0 then # one is training an auto-encoder diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index f049caeb4..a3e003383 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -855,18 +855,31 @@ def predict( ## embed token coords, concatenating along batch dimension # (which is taking care of through the varlen attention) # arguably we should to the mixed precision policy when creating the model in FSDP - tc_tokens = torch.cat( - [ - checkpoint( - tc_embed, - streams_data[i_b][ii].target_coords[fstep], - use_reentrant=False, - ) - if len(streams_data[i_b][ii].target_coords[fstep].shape) > 1 - else streams_data[i_b][ii].target_coords[fstep] - for i_b in range(len(streams_data)) - ] - ) + if self.cf.pred_gradient_checkpoint_mode: + tc_tokens = torch.cat( + [ + checkpoint( + tc_embed, + streams_data[i_b][ii].target_coords[fstep], + use_reentrant=False, + ) + if len(streams_data[i_b][ii].target_coords[fstep].shape) > 1 + else streams_data[i_b][ii].target_coords[fstep] + for i_b in range(len(streams_data)) + ] + ) + else: + + tc_tokens = torch.cat( + [ + tc_embed( + streams_data[i_b][ii].target_coords[fstep], + ) + if len(streams_data[i_b][ii].target_coords[fstep].shape) > 1 + else streams_data[i_b][ii].target_coords[fstep] + for i_b in range(len(streams_data)) + ] + ) # skip when coordinate embeddings yields nan (i.e. the coord embedding network diverged) if torch.isnan(tc_tokens).any(): @@ -906,6 +919,9 @@ def predict( ) # final prediction head to map back to physical space - preds_tokens += [checkpoint(self.pred_heads[ii], tc_tokens, use_reentrant=False)] + if self.cf.pred_gradient_checkpoint_mode: + preds_tokens += [checkpoint(self.pred_heads[ii], tc_tokens, use_reentrant=False)] + else: + preds_tokens += [self.pred_heads[ii](tc_tokens)] return preds_tokens From f2a78d3fbbc81c47268221b27d50bdb074057ec4 Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Mon, 27 Oct 2025 08:33:00 +0100 Subject: [PATCH 11/16] ruff embeddings.py --- src/weathergen/model/embeddings.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/weathergen/model/embeddings.py b/src/weathergen/model/embeddings.py index e18657312..bfd78fdb6 100644 --- a/src/weathergen/model/embeddings.py +++ b/src/weathergen/model/embeddings.py @@ -159,7 +159,11 @@ def forward_channels(self, x_in, centroids): # read out if self.unembed_mode == "full": - out = checkpoint(self.unembed, self.ln_final(x.flatten(-2, -1)), use_reentrant=False) + out = checkpoint( + self.unembed, + self.ln_final(x.flatten(-2, -1)), + use_reentrant=False, + ) elif self.unembed_mode == "block": out = [ checkpoint(ue, ln(x[:, i]), use_reentrant=False) From bca3381a93aab4a64b59d12a50b7205e4df2b822 Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Mon, 27 Oct 2025 11:19:23 +0100 Subject: [PATCH 12/16] remove ae_local_blocks_grdient_checkpoint_mode for __init__ --- src/weathergen/model/model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index ec21f0ee0..491625417 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -253,7 +253,6 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.sources_size = sources_size self.targets_num_channels = targets_num_channels self.targets_coords_size = targets_coords_size - self.ae_local_blocks_grdient_checkpoint_mode = self.cf.ae_local_blocks_grdient_checkpoint_mode ######################################### def create(self) -> "Model": @@ -744,7 +743,7 @@ def assimilate_local( tokens_global_all += [tokens_global_c] continue - if self.ae_local_blocks_grdient_checkpoint_mode: + if self.cf.ae_local_blocks_grdient_checkpoint_mode: for block in self.ae_local_blocks: tokens_c = checkpoint(block, tokens_c, cell_lens_c, use_reentrant=False) else: From 909cff30cb56eb673dd37ee3dd1176aa24c2f52a Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Mon, 27 Oct 2025 11:20:29 +0100 Subject: [PATCH 13/16] ruff the code --- src/weathergen/model/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 491625417..d236ac116 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -743,7 +743,7 @@ def assimilate_local( tokens_global_all += [tokens_global_c] continue - if self.cf.ae_local_blocks_grdient_checkpoint_mode: + if self.cf.ae_local_blocks_grdient_checkpoint_mode: for block in self.ae_local_blocks: tokens_c = checkpoint(block, tokens_c, cell_lens_c, use_reentrant=False) else: From e4b77463c4e2ee89e61fc868637241983f67fd76 Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Mon, 27 Oct 2025 11:39:03 +0100 Subject: [PATCH 14/16] ruff the code --- src/weathergen/model/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 68b77ff63..ee33736cf 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -773,7 +773,6 @@ def assimilate_local( cell_lens_c, ) - tokens_global_all += [tokens_global_c] tokens_global = torch.cat(tokens_global_all) From cf5b13fb6f1f05e8c0414679ae23906228f712dc Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Mon, 27 Oct 2025 11:58:01 +0100 Subject: [PATCH 15/16] ruff the code --- src/weathergen/model/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index 52456bd0a..860836a8c 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -791,9 +791,9 @@ def assimilate_global(self, model_params: ModelParams, tokens: torch.Tensor) -> for block in self.ae_global_blocks: tokens = checkpoint(block, tokens, use_reentrant=False) else: - for block in self.ae_global_blocks: + for block in self.ae_global_blocks: tokens = block(tokens) - + return tokens ######################################### From e0532c2d142ca7ce13cf5d22087d2e7d25364314 Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Mon, 27 Oct 2025 12:04:32 +0100 Subject: [PATCH 16/16] ruff the code --- src/weathergen/model/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index a3e003383..b77820ec4 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -869,7 +869,6 @@ def predict( ] ) else: - tc_tokens = torch.cat( [ tc_embed(