Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions config/default_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ embed_centroids_local_coords: False
embed_size_centroids: 0
embed_unembed_mode: "block"
embed_dropout_rate: 0.1
embed_gradient_checkpoint_mode: False

target_cell_local_prediction: True

Expand Down
54 changes: 40 additions & 14 deletions src/weathergen/model/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
norm_type="LayerNorm",
embed_size_centroids=64,
unembed_mode="full",
embed_gradient_checkpoint_mode=True,
stream_name="stream_embed",
):
"""Constructor
Expand All @@ -59,6 +60,7 @@ def __init__(
self.num_heads = num_heads
self.embed_size_centroids = embed_size_centroids
self.unembed_mode = unembed_mode
self.embed_gradient_checkpoint_mode = embed_gradient_checkpoint_mode

norm = torch.nn.LayerNorm if norm_type == "LayerNorm" else RMSNorm

Expand Down Expand Up @@ -148,23 +150,47 @@ def __init__(
def forward_channels(self, x_in, centroids):
peh = positional_encoding_harmonic

# embed provided input data
x = peh(checkpoint(self.embed, x_in.transpose(-2, -1), use_reentrant=False))
if self.embed_gradient_checkpoint_mode:
# embed provided input data
x = peh(checkpoint(self.embed, x_in.transpose(-2, -1), use_reentrant=False))

for layer in self.layers:
x = checkpoint(layer, x, use_reentrant=False)
for layer in self.layers:
x = checkpoint(layer, x, use_reentrant=False)

# read out
if self.unembed_mode == "full":
out = checkpoint(
self.unembed,
self.ln_final(x.flatten(-2, -1)),
use_reentrant=False,
)
elif self.unembed_mode == "block":
out = [
checkpoint(ue, ln(x[:, i]), use_reentrant=False)
for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True))
]
out = torch.stack(out, dim=1).flatten(-2, -1)
else:
assert False

# read out
if self.unembed_mode == "full":
out = checkpoint(self.unembed, self.ln_final(x.flatten(-2, -1)), use_reentrant=False)
elif self.unembed_mode == "block":
out = [
checkpoint(ue, ln(x[:, i]), use_reentrant=False)
for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True))
]
out = torch.stack(out, dim=1).flatten(-2, -1)
else:
assert False
# embed provided input data
x = peh(self.embed(x_in.transpose(-2, -1)))

for layer in self.layers:
x = layer(x)

# read out
if self.unembed_mode == "full":
out = self.unembed(self.ln_final(x.flatten(-2, -1)))
elif self.unembed_mode == "block":
out = [
ue(ln(x[:, i]))
for i, (ue, ln) in enumerate(zip(self.unembed, self.ln_final, strict=True))
]
out = torch.stack(out, dim=1).flatten(-2, -1)
else:
assert False

# append centroids
if self.embed_size_centroids > 0:
Expand Down
1 change: 1 addition & 0 deletions src/weathergen/model/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def create(self) -> torch.nn.ModuleList:
norm_type=self.cf.norm_type,
embed_size_centroids=self.cf.embed_size_centroids,
unembed_mode=self.cf.embed_unembed_mode,
embed_gradient_checkpoint_mode=self.cf.embed_gradient_checkpoint_mode,
stream_name=stream_name,
)
)
Expand Down
Loading