From e30a8be54b3f01a7ef2af3f477b2c821cfe35a01 Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Fri, 24 Oct 2025 14:11:35 +0200 Subject: [PATCH 1/6] add embed_gradient_checkpoint_mode to config --- config/default_config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/config/default_config.yml b/config/default_config.yml index 679f58dd3..ae4a24e35 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -6,6 +6,7 @@ embed_centroids_local_coords: False embed_size_centroids: 0 embed_unembed_mode: "block" embed_dropout_rate: 0.1 +embed_gradient_checkpoint_mode: False target_cell_local_prediction: True From 427f6ddeac5581bbf263c08bec630ffbfc08256b Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Fri, 24 Oct 2025 14:13:14 +0200 Subject: [PATCH 2/6] add embed_gradient_checkpoint_mode condition to forward_channels --- src/weathergen/model/embeddings.py | 50 +++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/src/weathergen/model/embeddings.py b/src/weathergen/model/embeddings.py index da916413f..e18657312 100644 --- a/src/weathergen/model/embeddings.py +++ b/src/weathergen/model/embeddings.py @@ -34,6 +34,7 @@ def __init__( norm_type="LayerNorm", embed_size_centroids=64, unembed_mode="full", + embed_gradient_checkpoint_mode=True, stream_name="stream_embed", ): """Constructor @@ -59,6 +60,7 @@ def __init__( self.num_heads = num_heads self.embed_size_centroids = embed_size_centroids self.unembed_mode = unembed_mode + self.embed_gradient_checkpoint_mode = embed_gradient_checkpoint_mode norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm @@ -148,23 +150,43 @@ def __init__( def forward_channels(self, x_in, centroids): peh = positional_encoding_harmonic - # embed provided input data - x = peh(checkpoint(self.embed, x_in.transpose(-2, -1), use_reentrant=False)) + if self.embed_gradient_checkpoint_mode: + # embed provided input data + x = peh(checkpoint(self.embed, x_in.transpose(-2, -1), use_reentrant=False)) - for layer in self.layers: - x = checkpoint(layer, x, use_reentrant=False) + for layer in self.layers: + x = checkpoint(layer, x, use_reentrant=False) + + # read out + if self.unembed_mode == "full": + out = checkpoint(self.unembed, self.ln_final(x.flatten(-2, -1)), use_reentrant=False) + elif self.unembed_mode == "block": + out = [ + checkpoint(ue, ln(x[:, i]), use_reentrant=False) + for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True)) + ] + out = torch.stack(out, dim=1).flatten(-2, -1) + else: + assert False - # read out - if self.unembed_mode == "full": - out = checkpoint(self.unembed, self.ln_final(x.flatten(-2, -1)), use_reentrant=False) - elif self.unembed_mode == "block": - out = [ - checkpoint(ue, ln(x[:, i]), use_reentrant=False) - for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True)) - ] - out = torch.stack(out, dim=1).flatten(-2, -1) else: - assert False + # embed provided input data + x = peh(self.embed(x_in.transpose(-2, -1))) + + for layer in self.layers: + x = layer(x) + + # read out + if self.unembed_mode == "full": + out = self.unembed(self.ln_final(x.flatten(-2, -1))) + elif self.unembed_mode == "block": + out = [ + ue(ln(x[:, i])) + for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True)) + ] + out = torch.stack(out, dim=1).flatten(-2, -1) + else: + assert False # append centroids if self.embed_size_centroids > 0: From cf509e3839ce4c0e08c9c5f05551414ef7302259 Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Fri, 24 Oct 2025 14:13:42 +0200 Subject: [PATCH 3/6] add embed_gradient_checkpoint_mode to args --- src/weathergen/model/engines.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 78d11a4a6..b23ed817d 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -71,6 +71,7 @@ def create(self) -> torch.nn.ModuleList: norm_type=self.cf.norm_type, embed_size_centroids=self.cf.embed_size_centroids, unembed_mode=self.cf.embed_unembed_mode, + embed_gradient_checkpoint_mode=self.cf.embed_gradient_checkpoint_mode, stream_name=stream_name, ) ) From 125d50d66a774f19e357d71b388a3045960daea5 Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Fri, 24 Oct 2025 14:25:55 +0200 Subject: [PATCH 4/6] test pipeline for embed_gradient_checkpoint_mode true --- config/default_config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index ae4a24e35..e6c8dec6f 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -6,7 +6,7 @@ embed_centroids_local_coords: False embed_size_centroids: 0 embed_unembed_mode: "block" embed_dropout_rate: 0.1 -embed_gradient_checkpoint_mode: False +embed_gradient_checkpoint_mode: True target_cell_local_prediction: True From 53c94a54055c357818d93be302ed88e62671e88d Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Fri, 24 Oct 2025 14:32:15 +0200 Subject: [PATCH 5/6] test pipeline for embed_gradient_checkpoint_mode false --- config/default_config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/default_config.yml b/config/default_config.yml index e6c8dec6f..ae4a24e35 100644 --- a/config/default_config.yml +++ b/config/default_config.yml @@ -6,7 +6,7 @@ embed_centroids_local_coords: False embed_size_centroids: 0 embed_unembed_mode: "block" embed_dropout_rate: 0.1 -embed_gradient_checkpoint_mode: True +embed_gradient_checkpoint_mode: False target_cell_local_prediction: True From f2a78d3fbbc81c47268221b27d50bdb074057ec4 Mon Sep 17 00:00:00 2001 From: Javad Kasravi Date: Mon, 27 Oct 2025 08:33:00 +0100 Subject: [PATCH 6/6] ruff embeddings.py --- src/weathergen/model/embeddings.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/weathergen/model/embeddings.py b/src/weathergen/model/embeddings.py index e18657312..bfd78fdb6 100644 --- a/src/weathergen/model/embeddings.py +++ b/src/weathergen/model/embeddings.py @@ -159,7 +159,11 @@ def forward_channels(self, x_in, centroids): # read out if self.unembed_mode == "full": - out = checkpoint(self.unembed, self.ln_final(x.flatten(-2, -1)), use_reentrant=False) + out = checkpoint( + self.unembed, + self.ln_final(x.flatten(-2, -1)), + use_reentrant=False, + ) elif self.unembed_mode == "block": out = [ checkpoint(ue, ln(x[:, i]), use_reentrant=False)