diff --git a/satflow/models/cloudgan.py b/satflow/models/cloudgan.py index e9d5361e..d3dd1c41 100644 --- a/satflow/models/cloudgan.py +++ b/satflow/models/cloudgan.py @@ -245,8 +245,7 @@ def training_step(self, batch, batch_idx, optimizer_idx): images, future_images = batch if self.condition_time: return self.train_per_timestep(images, future_images, optimizer_idx, batch_idx) - else: - return self.train_all_timestep(images, future_images, optimizer_idx, batch_idx) + return self.train_all_timestep(images, future_images, optimizer_idx, batch_idx) def val_all_timestep(self, images, future_images, batch_idx): # generate images @@ -328,8 +327,7 @@ def validation_step(self, batch, batch_idx): images, future_images = batch if self.condition_time: return self.val_per_timestep(images, future_images, batch_idx) - else: - return self.val_all_timestep(images, future_images, batch_idx) + return self.val_all_timestep(images, future_images, batch_idx) def forward(self, x, **kwargs): return self.generator.forward(x, **kwargs) diff --git a/satflow/models/gan/common.py b/satflow/models/gan/common.py index 6174fe1b..a9b89905 100644 --- a/satflow/models/gan/common.py +++ b/satflow/models/gan/common.py @@ -130,5 +130,4 @@ def cal_gradient_penalty( ((gradients + 1e-16).norm(2, dim=1) - constant) ** 2 ).mean() * lambda_gp # added eps return gradient_penalty, gradients - else: - return 0.0, None + return 0.0, None diff --git a/satflow/models/gan/generators.py b/satflow/models/gan/generators.py index 19336463..b2d92e90 100644 --- a/satflow/models/gan/generators.py +++ b/satflow/models/gan/generators.py @@ -427,5 +427,4 @@ def __init__( def forward(self, x): if self.outermost: return self.model(x) - else: # add skip connections - return torch.cat([x, self.model(x)], 1) + return torch.cat([x, self.model(x)], 1) diff --git a/satflow/models/layers/TimeDistributed.py b/satflow/models/layers/TimeDistributed.py index a45b5a05..0d25c6c1 100644 --- a/satflow/models/layers/TimeDistributed.py +++ b/satflow/models/layers/TimeDistributed.py @@ -22,11 +22,10 @@ def forward(self, *tensors, **kwargs): "input x with shape:(bs,seq_len,channels,width,height)" if self.low_mem or self.tdim != 1: return self.low_mem_forward(*tensors, **kwargs) - else: - # only support tdim=1 - inp_shape = tensors[0].shape - bs, seq_len = inp_shape[0], inp_shape[1] - out = self.module(*[x.view(bs * seq_len, *x.shape[2:]) for x in tensors], **kwargs) + # only support tdim=1 + inp_shape = tensors[0].shape + bs, seq_len = inp_shape[0], inp_shape[1] + out = self.module(*[x.view(bs * seq_len, *x.shape[2:]) for x in tensors], **kwargs) return self.format_output(out, bs, seq_len) def low_mem_forward(self, *tensors, **kwargs): diff --git a/satflow/models/utils.py b/satflow/models/utils.py index 8bb79aab..3e375f7f 100644 --- a/satflow/models/utils.py +++ b/satflow/models/utils.py @@ -31,7 +31,7 @@ def reverse_space_to_depth( dh=spatial_block_size, dw=spatial_block_size, ) - elif len(frames.shape) == 5: + if len(frames.shape) == 5: return einops.rearrange( frames, "b t h w (dt dh dw c) -> b (t dt) (h dh) (w dw) c", @@ -39,11 +39,10 @@ def reverse_space_to_depth( dh=spatial_block_size, dw=spatial_block_size, ) - else: - raise ValueError( - "Frames should be of rank 4 (batch, height, width, channels)" - " or rank 5 (batch, time, height, width, channels)" - ) + raise ValueError( + "Frames should be of rank 4 (batch, height, width, channels)" + " or rank 5 (batch, time, height, width, channels)" + ) def space_to_depth( @@ -57,7 +56,7 @@ def space_to_depth( dh=spatial_block_size, dw=spatial_block_size, ) - elif len(frames.shape) == 5: + if len(frames.shape) == 5: return einops.rearrange( frames, "b (t dt) (h dh) (w dw) c -> b t h w (dt dh dw c)", @@ -65,8 +64,7 @@ def space_to_depth( dh=spatial_block_size, dw=spatial_block_size, ) - else: - raise ValueError( - "Frames should be of rank 4 (batch, height, width, channels)" - " or rank 5 (batch, time, height, width, channels)" - ) + raise ValueError( + "Frames should be of rank 4 (batch, height, width, channels)" + " or rank 5 (batch, time, height, width, channels)" + )